def evaluate_multiple_clients(
    federated_data: FederatedData,
    client_ids: List[str],
    model: Model,
    params: Params,
    client_data_hparams: Optional[dataset_util.ClientDataHParams] = None
) -> Iterator[MetricResults]:
  """Evaluates model over input clients' datasets.

  Args:
    federated_data: Federated data separated per client.
    client_ids: Ids of clients to evaluate.
    model: Model implementation.
    params: Pytree of model parameters to be evaluated.
    client_data_hparams: Hyperparameters for client dataset preparation.

  Yields:
    Ordered mapping of metrics per client or empty mapping if a client's
    dataset is empty.
  """
  if client_data_hparams is not None:
    federated_data = federated_data.preprocess(
        lambda ds: dataset_util.preprocess_tf_dataset(ds, client_data_hparams))
  for _, client_dataset in prefetch.PrefetchClientDatasetsIterator(
      federated_data, client_ids):
    yield evaluate_single_client(client_dataset, model, params)
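# Usage sketch (illustrative only, not part of the module): evaluate initial
# parameters over a few clients of the toy example used in the tests below and
# average their losses. The hyperparameter values are assumptions, and it is
# assumed the federated data object exposes `client_ids` as a list.
data, model = test_util.create_toy_example(
    num_clients=10, num_clusters=4, num_classes=10, num_examples=100, seed=0)
params = model.init_params(next(hk.PRNGSequence(0)))
eval_hparams = dataset_util.ClientDataHParams(batch_size=20, num_epochs=1)
client_metrics = [
    m for m in evaluate_multiple_clients(
        federated_data=data,
        client_ids=list(data.client_ids)[:3],
        model=model,
        params=params,
        client_data_hparams=eval_hparams)
    if m  # Skip clients whose dataset was empty.
]
mean_loss = sum(m['loss'] for m in client_metrics) / max(len(client_metrics), 1)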
def test_evaluate_single_client(self):
  num_clients = 10
  num_classes = 10
  num_examples = 100
  client_data_hparams = dataset_util.ClientDataHParams(
      batch_size=20, num_epochs=3)
  data, model = test_util.create_toy_example(
      num_clients=num_clients,
      num_clusters=4,
      num_classes=num_classes,
      num_examples=num_examples,
      seed=0)
  rng_seq = hk.PRNGSequence(0)
  init_params = model.init_params(next(rng_seq))
  dataset = dataset_util.preprocess_tf_dataset(
      dataset_util.create_tf_dataset_for_clients(data), client_data_hparams)

  with self.subTest('tf dataset'):
    init_metrics = evaluation_util.evaluate_single_client(
        dataset=dataset, model=model, params=init_params)
    self.assertLess(0.0, init_metrics['loss'])

  with self.subTest('plain iterator'):
    init_metrics = evaluation_util.evaluate_single_client(
        dataset=dataset.as_numpy_iterator(), model=model, params=init_params)
    self.assertLess(0.0, init_metrics['loss'])
def test_preprocess_tf_dataset(self, hparams, expected_num_batches):
  x = np.arange(10 * 2).reshape((10, 2))
  numpy_data = collections.OrderedDict(x=x, y=x, z=x)
  dataset = tf.data.Dataset.from_tensor_slices(numpy_data)
  batches = list(dataset_util.preprocess_tf_dataset(dataset, hparams))
  self.assertLen(batches, expected_num_batches)
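# Standalone sketch of the same check outside the test harness. The hparams
# values here are assumptions chosen so the batch count does not depend on
# `drop_remainder`: 10 examples over 3 epochs is divisible by batch_size=5.
hparams = dataset_util.ClientDataHParams(batch_size=5, num_epochs=3)
dataset = tf.data.Dataset.from_tensor_slices(
    collections.OrderedDict(x=np.arange(10 * 2).reshape((10, 2))))
num_batches = len(list(dataset_util.preprocess_tf_dataset(dataset, hparams)))
assert num_batches == 10 * 3 // 5  # 6 batches in total.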
def train_multiple_clients(
    federated_data: FederatedData,
    client_ids: List[str],
    client_trainer: ClientTrainer[T],
    init_client_trainer_state: T,
    rng_seq: PRNGSequence,
    client_data_hparams: dataset_util.ClientDataHParams) -> Iterator[Any]:
  """Trains separate model for each client and records client updates.

  Depending on the value of `--fedjax_experimental_disable_parallel`, this will
  be run either sequentially or in parallel across local devices via
  `jax.pmap`.

  Args:
    federated_data: Federated data separated per client.
    client_ids: Ids of clients to train.
    client_trainer: ClientTrainer instance.
    init_client_trainer_state: Initial client trainer state. This will
      typically be derived from algorithm state before calling
      `train_multiple_clients`.
    rng_seq: Random key generator.
    client_data_hparams: Hyperparameters for client dataset preparation.

  Yields:
    Output of client trainer that is typically just an updated version of the
    input `init_client_trainer_state`. However, output is flexible.
  """
  should_run_parallel = (
      FLAGS.fedjax_experimental_disable_parallel == 'false' or
      (FLAGS.fedjax_experimental_disable_parallel == 'auto' and
       jax.local_device_count() > 1))
  if should_run_parallel:
    yield from _train_multiple_clients_parallel(
        federated_data=federated_data,
        client_ids=client_ids,
        client_trainer=client_trainer,
        init_client_trainer_state=init_client_trainer_state,
        rng_seq=rng_seq,
        client_data_hparams=client_data_hparams)
    return

  preprocessed_federated_data = federated_data.preprocess(
      lambda ds: dataset_util.preprocess_tf_dataset(ds, client_data_hparams))
  for _, client_dataset in prefetch.PrefetchClientDatasetsIterator(
      preprocessed_federated_data, client_ids):
    examples = zip(dataset_util.iterate(client_dataset), rng_seq)
    client_trainer_state = client_trainer.loop(init_client_trainer_state,
                                               examples)
    yield client_trainer_state
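# Usage sketch (illustrative only): run one round of local training over a
# sampled set of clients and collect the per-client trainer outputs.
# `my_client_trainer`, `my_init_state`, and `sampled_client_ids` are
# placeholders for a concrete `ClientTrainer` implementation, its initial
# state, and the sampled ids; the hyperparameter values are assumptions.
train_hparams = dataset_util.ClientDataHParams(batch_size=20, num_epochs=1)
rng_seq = hk.PRNGSequence(42)
client_outputs = list(
    train_multiple_clients(
        federated_data=federated_data,
        client_ids=sampled_client_ids,
        client_trainer=my_client_trainer,
        init_client_trainer_state=my_init_state,
        rng_seq=rng_seq,
        client_data_hparams=train_hparams))
# Each element of `client_outputs` is whatever the trainer's `loop` returns,
# typically an updated copy of `my_init_state` for that client.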
def _train_multiple_clients_parallel(
    federated_data: FederatedData,
    client_ids: List[str],
    client_trainer: ClientTrainer[T],
    init_client_trainer_state: T,
    rng_seq: PRNGSequence,
    client_data_hparams: dataset_util.ClientDataHParams) -> Iterator[Any]:
  """Trains separate model for each client and records client updates.

  It has the same inputs and return values as `train_multiple_clients` above,
  but parallelizes the training across multiple devices if available.

  Args:
    federated_data: Federated data separated per client.
    client_ids: Ids of clients to train.
    client_trainer: ClientTrainer instance.
    init_client_trainer_state: Initial client trainer state. This will
      typically be derived from algorithm state before calling
      `train_multiple_clients`.
    rng_seq: Random key generator.
    client_data_hparams: Hyperparameters for client dataset preparation. The
      `drop_remainder` field is automatically set to True to ensure that all
      batches have the same batch size.

  Yields:
    Output of client trainer that is typically just an updated version of the
    input `init_client_trainer_state`. However, output is flexible.
  """
  client_data_hparams = client_data_hparams._replace(drop_remainder=True)
  client_data = []
  preprocessed_federated_data = federated_data.preprocess(
      lambda ds: dataset_util.preprocess_tf_dataset(ds, client_data_hparams))
  for _, client_dataset in prefetch.PrefetchClientDatasetsIterator(
      preprocessed_federated_data, client_ids):
    client_data.append(list(dataset_util.iterate(client_dataset)))
  # Sort by length to group similarly sized datasets together to minimize the
  # amount of fillvalues needed.
  client_data = sorted(client_data, key=len)
  # TODO(b/177346980): Handle case where all clients' data is empty by padding.
  fillvalue = _get_fillvalue(client_data)

  num_local_devices = jax.local_device_count()
  init_stack_state = tree_util.tree_broadcast(
      init_client_trainer_state, axis_size=num_local_devices)
  stack_rng = jrandom.split(next(rng_seq), num_local_devices)

  quotient, remainder = divmod(len(client_ids), num_local_devices)
  num_iterations = quotient + bool(remainder)
  for i in range(num_iterations):
    stack_state = init_stack_state
    streams = client_data[num_local_devices * i:num_local_devices * (i + 1)]
    # Handle number of clients not divisible by num_local_devices.
    if len(streams) < num_local_devices:
      client_count = len(streams)
      streams.extend(
          [[fillvalue] for _ in range(num_local_devices - client_count)])
    else:
      client_count = num_local_devices
    for batches in itertools.zip_longest(*streams, fillvalue=fillvalue):
      mask = jnp.array([b is not fillvalue for b in batches])
      stack_mask = tree_util.tree_stack(mask)
      stack_batch = tree_util.tree_stack(batches)
      # Mask must be the first input.
      stack_state, stack_rng = _pmap_step(stack_mask, client_trainer,
                                          stack_state, stack_batch, stack_rng)
    # Unstack stack_state, yield each one.
    for j, final_state in enumerate(tree_util.tree_unstack(stack_state)):
      if j < client_count:
        yield final_state
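# Minimal sketch of the stacking pattern used above, written against
# jax.tree_util directly rather than this module's `tree_util` helpers (whose
# exact implementations are assumed to behave like this): per-client pytrees
# are stacked along a new leading device axis so one pmapped step can update
# all of them at once, and unstacked afterwards to recover per-client states.
import jax
import jax.numpy as jnp


def tree_stack_example(pytrees):
  # Stack corresponding leaves of several pytrees along a new leading axis.
  return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)


def tree_unstack_example(stacked):
  # Invert the stacking: yield one pytree per slice of the leading axis.
  leaves, treedef = jax.tree_util.tree_flatten(stacked)
  for i in range(leaves[0].shape[0]):
    yield jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])


client_states = [{'w': jnp.ones(3) * i} for i in range(2)]
stacked = tree_stack_example(client_states)      # {'w': array of shape (2, 3)}
restored = list(tree_unstack_example(stacked))   # Two per-client pytrees again.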