def test_simple_training(self):
    """Runs three rounds on one fixed batch and expects the loss to improve.

    The mean of the losses reported after the first round must be strictly
    lower than the loss reported by the first round.
    """
    process = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    state = process.initialize()

    # Test out manually setting weights:
    reference_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=True)

    @tf.function
    def deterministic_batch():
        features = np.ones([1, 784], dtype=np.float32)
        labels = np.ones([1, 1], dtype=np.int64)
        return collections.OrderedDict(x=features, y=labels)

    batch = deterministic_batch()
    # A single client holding a single batch.
    federated_data = [[batch]]

    def keras_evaluate(current_state):
        tff.learning.assign_weights_to_keras_model(
            reference_model, current_state.model)
        # N.B. The loss computed here won't match the loss computed by TFF
        # because of the Dropout layer.
        reference_model.test_on_batch(**batch)

    round_losses = []
    for _ in range(3):
        keras_evaluate(state)
        state, output = process.next(state, federated_data)
        round_losses.append(output.loss)
    keras_evaluate(state)

    first_loss = round_losses[0]
    later_losses = round_losses[1:]
    self.assertLess(np.mean(later_losses), first_loss)
def test_construction(self):
    """Checks the built process type and the client-data type of `next`."""
    process = flars_fedavg.build_federated_averaging_process(_keras_model_fn)
    self.assertIsInstance(process, tff.utils.IterativeProcess)

    # The second parameter of `next` is the federated client data.
    next_parameters = process.next.type_signature.parameter
    client_data_type = next_parameters[1]
    self.assertEqual(
        str(client_data_type),
        '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS')
def test_simple_training(self):
    """Runs three rounds on one fixed batch and expects the loss to improve.

    The mean of the losses returned after the first round must be strictly
    lower than the loss returned by the first round.
    """
    process = flars_fedavg.build_federated_averaging_process(_keras_model_fn)
    state = process.initialize()

    Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

    # Test out manually setting weights:
    reference_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=True)

    def deterministic_batch():
        features = np.ones([1, 784], dtype=np.float32)
        labels = np.ones([1, 1], dtype=np.int64)
        return Batch(x=features, y=labels)

    batch = tff.tf_computation(deterministic_batch)()
    # A single client holding a single batch.
    federated_data = [[batch]]

    def keras_evaluate(current_state):
        tff.learning.assign_weights_to_keras_model(
            reference_model, current_state.model)
        # N.B. The loss computed here won't match the
        # loss computed by TFF because of the Dropout layer.
        reference_model.test_on_batch(batch.x, batch.y)

    round_losses = []
    for _ in range(3):
        keras_evaluate(state)
        state, loss = process.next(state, federated_data)
        round_losses.append(loss)
    keras_evaluate(state)

    first_loss = round_losses[0]
    later_losses = round_losses[1:]
    self.assertLess(np.mean(later_losses), first_loss)
def test_construction(self):
    """Checks the built process type and the client-data type of `next`."""
    process = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    self.assertIsInstance(process, tff.templates.IterativeProcess)

    # The second parameter of `next` is the federated client data.
    next_parameters = process.next.type_signature.parameter
    client_data_type = next_parameters[1]
    self.assertEqual(
        str(client_data_type),
        '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS')
def test_self_contained_example_keras_model(self):
    """Trains for two rounds and expects the second loss to be lower."""
    make_dataset = create_client_data()
    train_data = [make_dataset()]

    trainer = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    state = trainer.initialize()

    # Track the loss reported by each round.
    round_losses = []
    for _ in range(2):
        state, outputs = trainer.next(state, train_data)
        round_losses.append(outputs.loss)

    first_loss, second_loss = round_losses
    self.assertLess(second_loss, first_loss)
def test_self_contained_example_keras_model(self):
    """Trains for two rounds and expects the second loss to be lower."""
    make_dataset = create_client_data()
    train_data = [make_dataset()]
    # One evaluated batch from the first client's data, passed to the model
    # wrapper below.
    sample_batch = self.evaluate(next(iter(train_data[0])))

    def model_fn():
        keras_model = tff.simulation.models.mnist.create_simple_keras_model()
        return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

    trainer = flars_fedavg.build_federated_averaging_process(model_fn)
    state = trainer.initialize()

    # Track the loss reported by each round.
    round_losses = []
    for _ in range(2):
        state, outputs = trainer.next(state, train_data)
        round_losses.append(outputs.loss)

    first_loss, second_loss = round_losses
    self.assertLess(second_loss, first_loss)
def _federated_averaging_training_loop(model_fn,
                                       client_optimizer_fn,
                                       server_optimizer_fn,
                                       client_datasets_fn,
                                       evaluate_fn,
                                       total_rounds=500,
                                       rounds_per_eval=1,
                                       metrics_hook=None):
    """A simple example of training loop for the Federated Averaging algorithm.

    Resumes from the latest checkpoint under
    `FLAGS.root_output_dir/FLAGS.exp_name` when one can be loaded; otherwise
    starts from a freshly initialized state and saves it as round 0.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer`.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer`.
      client_datasets_fn: A function that takes the round number, and returns a
        list of `tf.data.Dataset`, one per client.
      evaluate_fn: A function that takes state, performs evaluation, and
        returns evaluations metrics. Only called when `metrics_hook` is given.
      total_rounds: Number of rounds to train.
      rounds_per_eval: Evaluate and call `metrics_hook` every this many rounds
        (and always on the final round).
      metrics_hook: A function taking arguments (training metrics, evaluation
        metrics, and round number). Optional.

    Returns:
      Final `ServerState`.

    Raises:
      ValueError: If `FLAGS.server_optimizer` is not 'flars'.
    """
    logging.info('Starting federated training loop')

    checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)
    checkpoint_manager_obj = checkpoint_manager.FileCheckpointManager(
        checkpoint_dir)

    if FLAGS.server_optimizer != 'flars':
        logging.error('Unsupported server_optimzier: %s', FLAGS.server_optimizer)
        # Fail fast: without this, the code below would use
        # `iterative_process`, which is only bound for the 'flars' optimizer.
        raise ValueError(
            'Unsupported server_optimizer: {}'.format(FLAGS.server_optimizer))

    iterative_process = flars_fedavg.build_federated_averaging_process(
        model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn)
    ServerState = flars_fedavg.ServerState  # pylint: disable=invalid-name

    # construct an initial state here to act as a checkpoint template
    inital_state = iterative_process.initialize()
    inital_state = ServerState.from_tff_result(inital_state)

    logging.info('Looking for checkpoints in \'%s\'', checkpoint_dir)
    state, round_num = checkpoint_manager_obj.load_latest_checkpoint(
        inital_state)
    if state is None:
        logging.info('No previous checkpoints, initializing experiment')
        state = inital_state
        round_num = 0
        if metrics_hook is not None:
            # No training has happened yet, so training metrics are empty.
            eval_metrics = evaluate_fn(state)
            metrics_hook({}, eval_metrics, round_num)
        checkpoint_manager_obj.save_checkpoint(state, 0)
    else:
        logging.info('Restarted from checkpoint round %d', round_num)

    while round_num < total_rounds:
        round_num += 1
        train_metrics = {}

        # Reset the executor to clear the cache, and clear the default graph to
        # garbage collect tf.Functions that will no longer be used.
        tff.framework.set_default_executor(
            tff.framework.local_executor_factory(max_fanout=25))
        tf.compat.v1.reset_default_graph()

        round_start_time = time.time()
        data_prep_start_time = time.time()
        train_data = client_datasets_fn(round_num)
        train_metrics['prepare_datasets_secs'] = (
            time.time() - data_prep_start_time)

        training_start_time = time.time()
        state, tff_train_metrics = iterative_process.next(state, train_data)
        state = ServerState.from_tff_result(state)
        tff_train_metrics = tff_train_metrics._asdict(recursive=True)
        train_metrics.update(tff_train_metrics)
        train_metrics['training_secs'] = time.time() - training_start_time

        logging.info('Round {:2d} elapsed time: {:.2f}s .'.format(
            round_num, (time.time() - round_start_time)))
        train_metrics['total_round_secs'] = time.time() - round_start_time

        # Always checkpoint the final round so the finished run is persisted.
        if (round_num % FLAGS.rounds_per_checkpoint == 0 or
                round_num == total_rounds):
            save_checkpoint_start_time = time.time()
            checkpoint_manager_obj.save_checkpoint(state, round_num)
            train_metrics['save_checkpoint_secs'] = (
                time.time() - save_checkpoint_start_time)

        # Likewise, always evaluate on the final round.
        if round_num % rounds_per_eval == 0 or round_num == total_rounds:
            if metrics_hook is not None:
                eval_metrics = evaluate_fn(state)
                metrics_hook(train_metrics, eval_metrics, round_num)

    # The docstring has always promised the final state; actually return it.
    return state
def federated_averaging_training_loop(model_fn,
                                      server_optimizer_fn,
                                      client_datasets_fn,
                                      total_rounds=500,
                                      rounds_per_eval=1,
                                      metrics_hook=lambda *args: None):
    """Runs Federated Averaging rounds with checkpointing and a metrics hook.

    The loop resumes from the newest readable checkpoint under
    `FLAGS.root_output_dir/FLAGS.exp_name`; checkpoints that fail to load with
    an `OSError` are deleted and the next newest one is tried.

    NOTE(review): this function reads and writes `metrics_hook.results`. The
    default lambda has no such attribute until a checkpoint restore assigns
    it, so callers presumably must pass a hook object exposing `results` --
    confirm against `write_checkpoint` and the call sites.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer`.
      client_datasets_fn: A function that takes the round number, and returns a
        list of `tf.data.Dataset`, one per client.
      total_rounds: Number of rounds to train.
      rounds_per_eval: How often to call the `metrics_hook` function.
      metrics_hook: A function taking arguments (server_state, train_metrics,
        round_num) and performs evaluation. Optional.

    Returns:
      Final `ServerState`.
    """
    logging.info('Starting federated_training_loop')

    checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)

    # Only the 'flars' setting selects the FLARS implementation; every other
    # value uses the simple Federated Averaging implementation.
    if FLAGS.server_optimizer == 'flars':
        fedavg_impl = flars_fedavg
    else:
        fedavg_impl = simple_fedavg
    iterative_process = fedavg_impl.build_federated_averaging_process(
        model_fn, server_optimizer_fn=server_optimizer_fn)
    ServerState = fedavg_impl.ServerState  # pylint: disable=invalid-name

    server_state = ServerState.from_anon_tuple(iterative_process.initialize())
    round_num = None
    train_metrics = {}

    candidate_checkpoint = checkpoint_utils.latest_checkpoint(
        checkpoint_dir, CHECKPOINT_PREFIX)
    logging.info('Looking for checkpoints in [%s/%s]', checkpoint_dir,
                 CHECKPOINT_PREFIX)
    while candidate_checkpoint is not None:
        # Restart from a previous round.
        logging.info('Loading a previous checkpoint')
        try:
            server_state, metrics_hook.results, round_num = read_checkpoint(
                candidate_checkpoint, server_state)
        except OSError as e:
            # Likely corrupted checkpoint, possibly job died while writing.
            # Delete the checkpoint directory and try again.
            logging.error('Exception [%s]', e)
            logging.warning('Deleteing likely corrupted checkpoint at [%s]',
                            candidate_checkpoint)
            tf.io.gfile.rmtree(candidate_checkpoint)
            candidate_checkpoint = checkpoint_utils.latest_checkpoint(
                checkpoint_dir, CHECKPOINT_PREFIX)
        else:
            break

    if round_num is None:
        # Write the initial checkpoint
        logging.info('No previous checkpoints, initializing experiment')
        round_num = 0
        metrics_hook(server_state, train_metrics, round_num)
        write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                         round_num)
    else:
        logging.info('Restarted from checkpoint round %d', round_num)

    while round_num < total_rounds:
        round_num += 1

        # Reset the executor to clear the cache, and clear the default graph to
        # garbage collect tf.Functions that will no longer be used.
        tff.framework.set_default_executor(
            tff.framework.create_local_executor(max_fanout=25))
        tf.compat.v1.reset_default_graph()

        round_began = time.time()
        prep_began = time.time()
        federated_train_data = client_datasets_fn(round_num)
        train_metrics['prepare_datasets_secs'] = time.time() - prep_began

        training_began = time.time()
        raw_server_state, round_metrics = iterative_process.next(
            server_state, federated_train_data)
        server_state = ServerState.from_anon_tuple(raw_server_state)
        train_metrics.update(round_metrics._asdict(recursive=True))
        train_metrics['training_secs'] = time.time() - training_began

        logging.info('Round {:2d} elapsed time: {:.2f}s .'.format(
            round_num, (time.time() - round_began)))
        train_metrics['total_round_secs'] = time.time() - round_began

        if round_num % FLAGS.rounds_per_checkpoint == 0:
            checkpoint_began = time.time()
            write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                             round_num)
            train_metrics['write_checkpoint_secs'] = (
                time.time() - checkpoint_began)

        if round_num % rounds_per_eval == 0:
            metrics_hook(server_state, train_metrics, round_num)

    # Final evaluation and checkpoint, performed regardless of the periodic
    # schedules above.
    metrics_hook(server_state, train_metrics, total_rounds)
    write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                     round_num)

    return server_state