def fit_round(self, rnd: int) -> Optional[Weights]:
    """Run one round of federated training and aggregate the outcome.

    The strategy decides which clients take part (`configure_fit`) and how
    their results are combined (`aggregate_fit`).

    Returns `None` without contacting any client when the strategy yields no
    client instructions; otherwise returns the value of `aggregate_fit`.
    """
    # Ask the strategy for the (client, instruction) pairs of this round
    instructions = self.strategy.configure_fit(
        rnd=rnd, weights=self.weights, client_manager=self._client_manager
    )
    num_sampled = len(instructions)
    log(
        DEBUG,
        "fit_round: strategy sampled %s clients (out of %s)",
        num_sampled,
        self._client_manager.num_available(),
    )

    # Nothing to train on: bail out early
    if not instructions:
        log(INFO, "fit_round: no clients sampled, cancel fit")
        return None

    # Run training on every selected client and split successes from failures
    results, failures = fit_clients(instructions)
    num_results, num_failures = len(results), len(failures)
    log(
        DEBUG,
        "fit_round received %s results and %s failures",
        num_results,
        num_failures,
    )

    # Hand everything to the strategy for aggregation
    aggregated = self.strategy.aggregate_fit(rnd, results, failures)
    return aggregated
def start(self) -> None:
    """Start all instances of the cluster.

    Instances are grouped with `group_instances_by_specs` and each group is
    created by `create_instances` in a thread pool (at most 10 workers).

    Raises
    ------
    StartFailed
        If creating any instance group raised. The cluster is terminated via
        `self.terminate()` first, and the original exception is attached as
        the cause so the root failure stays visible in the traceback.
    """
    instance_groups = group_instances_by_specs(self.instances)

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [
            executor.submit(
                create_instances, self.adapter, instance_group, self.timeout
            )
            for instance_group in instance_groups
        ]
        # Block until every creation task has finished (or failed)
        concurrent.futures.wait(futures)

        try:
            # `result()` re-raises whatever the worker raised
            for future in futures:
                future.result()
        # pylint: disable=broad-except
        except Exception as exc:
            log(
                ERROR,
                "Failed to start the cluster completely. Shutting down...",
            )
            log(ERROR, exc)

            # All futures are already done after `wait`, so this is only a
            # safeguard; `cancel()` has no effect on finished futures.
            for future in futures:
                future.cancel()

            self.terminate()
            # Chain explicitly so callers can inspect `__cause__`
            raise StartFailed() from exc

    for ins in self.instances:
        log(DEBUG, ins)
def on_configure_fit(
    self, rnd: int, weights: Weights, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
    """Select clients and build their fit instructions for round `rnd`.

    Returns an empty list when `client_manager.wait_for` reports that the
    required number of clients did not become available within
    `WAIT_TIMEOUT`. Otherwise every sampled client is paired with the same
    `FitIns` object.
    """
    # Wait for the minimum number of clients before doing anything else
    sample_size, min_num_clients = self.num_fit_clients(
        client_manager.num_available()
    )
    enough_clients = client_manager.wait_for(
        num_clients=min_num_clients, timeout=WAIT_TIMEOUT
    )
    if not enough_clients:
        # Too few clients: skip this round
        log(
            INFO,
            "FedFS: not enough clients available after timeout %s",
            WAIT_TIMEOUT,
        )
        return []

    # Pick the participants of this round
    selected = self._contribution_based_sampling(
        sample_size=sample_size, client_manager=client_manager
    )

    # Serialize the weights and assemble the round configuration
    parameters = weights_to_parameters(weights)
    config = {}
    if self.on_fit_config_fn is not None:
        # A user-supplied function replaces the empty default config
        config = self.on_fit_config_fn(rnd)

    # The timeout depends on whether this is a fast or a slow round
    if is_fast_round(rnd - 1, self.r_fast, self.r_slow):
        round_timeout = self.t_fast
    else:
        round_timeout = self.t_slow
    config["timeout"] = str(round_timeout)

    # One shared instruction object for all selected clients
    fit_ins = FitIns(parameters, config)
    return [(proxy, fit_ins) for proxy in selected]
def evaluate(
    self, rnd: int
) -> Optional[Tuple[Optional[float], EvaluateResultsAndFailures]]:
    """Evaluate the current global weights on clients chosen by the strategy.

    Returns `None` when the strategy yields no client instructions.
    Otherwise returns a pair of the value produced by
    `strategy.aggregate_evaluate` and the raw `(results, failures)` object
    returned by `evaluate_clients`.
    """
    # Let the strategy choose clients and their evaluation instructions
    instructions = self.strategy.configure_evaluate(
        rnd=rnd, weights=self.weights, client_manager=self._client_manager
    )
    if not instructions:
        log(INFO, "evaluate: no clients sampled, cancel federated evaluation")
        return None

    num_sampled = len(instructions)
    log(DEBUG, "evaluate: strategy sampled %s clients", num_sampled)

    # Run the evaluation remotely and split the outcome
    results_and_failures = evaluate_clients(instructions)
    results, failures = results_and_failures
    log(
        DEBUG,
        "evaluate received %s results and %s failures",
        len(results),
        len(failures),
    )

    # Combine the per-client results into a single value
    aggregated_loss = self.strategy.aggregate_evaluate(rnd, results, failures)
    return aggregated_loss, results_and_failures
def upload_all(
    self, local_path: str, remote_path: str
) -> Dict[str, SFTPAttributes]:
    """Upload one local file to every instance in parallel.

    `self.upload` is called once per name returned by
    `self.get_instance_names()`, using a thread pool of at most 10 workers.

    Returns a mapping from instance name to the value returned by
    `self.upload`. Instances whose upload raised are logged at ERROR level
    and left out of the mapping.
    """
    uploaded: Dict[str, SFTPAttributes] = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        # Remember which instance each submitted upload belongs to
        pending = {}
        for name in self.get_instance_names():
            task = executor.submit(self.upload, name, local_path, remote_path)
            pending[task] = name

        # Collect outcomes in completion order
        for task in concurrent.futures.as_completed(pending):
            name = pending[task]
            try:
                uploaded[name] = task.result()
            # pylint: disable=broad-except
            except Exception as exc:
                log(ERROR, (name, exc))

    return uploaded
def evaluate(self, ins: fl.common.EvaluateIns) -> fl.common.EvaluateRes:
    """Evaluate the received parameters on this client's test dataset.

    The parameters in `ins` are converted to weights, loaded into
    `self.model`, and scored with `keras_evaluate` on `self.ds_test` using a
    batch size of `self.num_examples_test`.

    Returns an `EvaluateRes` carrying the number of test examples, the loss
    and the accuracy.
    """
    received_weights = fl.common.parameters_to_weights(ins.parameters)
    round_config = ins.config
    log(
        DEBUG,
        "evaluate on %s (examples: %s), config %s",
        self.cid,
        self.num_examples_test,
        round_config,
    )

    # Replace the local weights with the ones sent by the server
    self.model.set_weights(received_weights)

    # Score the updated model on the local test data
    loss, accuracy = keras_evaluate(
        self.model, self.ds_test, batch_size=self.num_examples_test
    )

    # Report the example count together with both metrics
    return fl.common.EvaluateRes(
        num_examples=self.num_examples_test, loss=loss, accuracy=accuracy
    )
def start_client(
    server_address: str,
    client: Client,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
) -> None:
    """Connect a Flower client to a gRPC server and serve its requests.

    The client keeps answering server messages until `handle` signals that
    the session is over. If `handle` then reports a non-zero sleep duration,
    the client sleeps that long and opens a new connection; otherwise it
    shuts down.

    Arguments:
        server_address: str. The IPv6 address of the server. If the Flower
            server runs on the same machine on port 8080, then
            `server_address` would be `"[::]:8080"`.
        client: flwr.client.Client. An implementation of the abstract base
            class `flwr.client.Client`.
        grpc_max_message_length: int (default: 536_870_912, this equals
            512MB). The maximum length of gRPC messages that can be exchanged
            with the Flower server. The default should be sufficient for most
            models. Users who train very large models might need to increase
            this value. Note that the Flower server needs to be started with
            the same value (see `flwr.server.start_server`), otherwise it
            will not know about the increased limit and block larger
            messages.

    Returns:
        None.
    """
    while True:
        sleep_duration: int = 0

        with insecure_grpc_connection(
            server_address, max_message_length=grpc_max_message_length
        ) as (receive, send):
            log(INFO, "Opened (insecure) gRPC connection")

            # Answer server messages until `handle` asks us to stop
            keep_going = True
            while keep_going:
                client_message, sleep_duration, keep_going = handle(
                    client, receive()
                )
                send(client_message)

        # The connection is closed here; decide whether to come back later
        if sleep_duration == 0:
            log(INFO, "Disconnect and shut down")
            return

        log(
            INFO,
            "Disconnect, then re-establish connection after %s second(s)",
            sleep_duration,
        )
        time.sleep(sleep_duration)
def _get_initial_parameters(self) -> Parameters:
    """Return the parameters the first round starts from.

    Parameters supplied by `strategy.initialize_parameters` take precedence.
    When the strategy returns `None`, one client sampled from the client
    manager is asked for its parameters instead.
    """
    from_strategy: Optional[Parameters] = self.strategy.initialize_parameters(
        client_manager=self._client_manager
    )

    if from_strategy is None:
        # No server-side initialization: fall back to a single client
        log(INFO, "Requesting initial parameters from one random client")
        chosen_client = self._client_manager.sample(1)[0]
        response = chosen_client.get_parameters()
        log(INFO, "Received initial parameters from one random client")
        return response.parameters

    log(INFO, "Using initial parameters provided by strategy")
    return from_strategy
def run(baseline: str, setting: str, adapter: str) -> None:
    """Provision a cluster and launch one baseline run on it.

    The nine logged steps are: configure the cluster, start it, upload the
    wheel, install it (plain and with framework extras), download the
    dataset, start the logserver, start the Flower server, start all Flower
    clients, and install the shutdown watcher. Finally a command for tailing
    the central logfile is logged.
    """
    print(f"Starting baseline with {setting} settings.")

    # The home directory differs between the docker adapter and the others
    if adapter == "docker":
        wheel_remote_path = f"/root/{WHEEL_FILENAME}"
    else:
        wheel_remote_path = f"/home/ubuntu/{WHEEL_FILENAME}"

    settings = load_baseline_setting(baseline, setting)

    # The configured instances are extended by one logserver instance
    instances = settings.instances
    logserver_instance = Instance(
        name="logserver", group="logserver", num_cpu=2, num_ram=2
    )
    instances.append(logserver_instance)

    log(INFO, "(1/9) Configure cluster.")
    cluster = configure_cluster(adapter, instances, baseline, setting)

    # Starting the cluster takes some time
    log(INFO, "(2/9) Start cluster.")
    cluster.start()

    log(INFO, "(3/9) Upload wheel to all instances.")
    cluster.upload_all(WHEEL_LOCAL_PATH, wheel_remote_path)

    # Install once without extras, then again with the framework extras
    log(INFO, "(4/9) Install wheel on all instances.")
    cluster.exec_all(command.install_wheel(wheel_remote_path))
    extras = ["examples-tensorflow" if "tf_" in baseline else "examples-pytorch"]
    cluster.exec_all(
        command.install_wheel(
            wheel_remote_path=wheel_remote_path, wheel_extras=extras
        )
    )

    log(INFO, "(5/9) Download dataset on server and clients.")
    cluster.exec_all(
        command.download_dataset(baseline=baseline),
        groups=["server", "clients"],
    )

    log(INFO, "(6/9) Start logserver.")
    logserver = cluster.get_instance("logserver")
    cluster.exec(
        logserver.name,
        command.start_logserver(
            logserver_s3_bucket=CONFIG.get("aws", "logserver_s3_bucket"),
            logserver_s3_key=f"{baseline}_{setting}_{now()}.log",
        ),
    )

    # Server and clients all send their logs to the logserver on port 8081
    log(INFO, "(7/9) Start server.")
    log_host = f"{logserver.private_ip}:8081"
    cluster.exec(
        "server",
        command.start_server(
            log_host=log_host,
            baseline=baseline,
            setting=setting,
        ),
    )

    log(INFO, "(8/9) Start clients.")
    server = cluster.get_instance("server")
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        # Submit one start command per client, then wait for all of them
        client_futures = []
        for client_setting in settings.clients:
            start_command = command.start_client(
                log_host=log_host,
                server_address=f"{server.private_ip}:8080",
                baseline=baseline,
                setting=setting,
                cid=client_setting.cid,
            )
            client_futures.append(
                executor.submit(
                    cluster.exec, client_setting.instance_name, start_command
                )
            )
        concurrent.futures.wait(client_futures)

    # The watcher shuts an instance down once no Flower process runs on it
    log(INFO, "(9/9) Start shutdown watcher script.")
    cluster.exec_all(command.watch_and_shutdown("flwr", adapter))

    # Tell the user how to follow the central logfile
    if adapter == "docker":
        private_key = DOCKER_PRIVATE_KEY
    else:
        private_key = path.expanduser(CONFIG.get("ssh", "private_key"))

    log(
        INFO,
        "If you would like to tail the central logfile run:\n\n\t%s\n",
        command.tail_logfile(adapter, private_key, logserver),
    )
def start_server(  # pylint: disable=too-many-arguments
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    force_final_distributed_eval: bool = False,
    certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
) -> History:
    """Start a Flower server using the gRPC transport layer.

    The gRPC server is always stopped before this function returns or
    raises, so a failure during training does not leave it running.

    Arguments
    ---------
    server_address: str (default: `DEFAULT_SERVER_ADDRESS`).
        The IPv6 address of the server, e.g. `"[::]:8080"`.
    server: Optional[flwr.server.Server] (default: None).
        An implementation of the abstract base class `flwr.server.Server`.
        If no instance is provided, `_init_defaults` creates one.
    config: Optional[Dict[str, int]] (default: None).
        The only currently supported value is `num_rounds`, so a full
        configuration object instructing the server to perform three rounds
        of federated learning looks like the following:
        `{"num_rounds": 3}`.
    strategy: Optional[flwr.server.Strategy] (default: None).
        An implementation of the abstract base class
        `flwr.server.Strategy`. If no strategy is provided, a default is
        created by `_init_defaults` (expected to be
        `flwr.server.strategy.FedAvg` - confirm there).
    grpc_max_message_length: int (default: 536_870_912, this equals 512MB).
        The maximum length of gRPC messages that can be exchanged with the
        Flower clients. The default should be sufficient for most models.
        Users who train very large models might need to increase this
        value. Note that the Flower clients need to be started with the
        same value (see `flwr.client.start_client`), otherwise clients will
        not know about the increased limit and block larger messages.
    force_final_distributed_eval: bool (default: False).
        Forces a distributed evaluation to occur after the last training
        epoch when enabled.
    certificates: Optional[Tuple[bytes, bytes, bytes]] (default: None).
        Tuple containing root certificate, server certificate, and private
        key to start a secure SSL-enabled server. The tuple is expected to
        have three bytes elements in the following order:

            * CA certificate.
            * server certificate.
            * server private key.

    Returns
    -------
    hist: flwr.server.history.History.
        Object containing metrics from training.

    Examples
    --------
    Starting an insecure server:

    >>> start_server()

    Starting a SSL-enabled server:

    >>> start_server(
    ...     certificates=(
    ...         Path("/crts/root.pem").read_bytes(),
    ...         Path("/crts/localhost.crt").read_bytes(),
    ...         Path("/crts/localhost.key").read_bytes()
    ...     )
    ... )
    """
    initialized_server, initialized_config = _init_defaults(
        server, config, strategy
    )

    # Start gRPC server
    grpc_server = start_grpc_server(
        client_manager=initialized_server.client_manager(),
        server_address=server_address,
        max_message_length=grpc_max_message_length,
        certificates=certificates,
    )

    # From here on the gRPC server is live: release it on every exit path
    try:
        num_rounds = initialized_config["num_rounds"]
        ssl_status = "enabled" if certificates is not None else "disabled"
        msg = f"Flower server running ({num_rounds} rounds)\nSSL is {ssl_status}"
        log(INFO, msg)

        hist = _fl(
            server=initialized_server,
            config=initialized_config,
            force_final_distributed_eval=force_final_distributed_eval,
        )
    finally:
        # Stop the gRPC server
        grpc_server.stop(grace=1)

    return hist
def custom_fit(
    model: tf.keras.Model,
    dataset: tf.data.Dataset,
    num_epochs: int,
    batch_size: int,
    callbacks: List[tf.keras.callbacks.Callback],
    delay_factor: float = 0.0,
    timeout: Optional[int] = None,
) -> Tuple[bool, float, int]:
    """Train `model` with a hand-written loop using an Adam optimizer.

    After every batch the loop can sleep for `delay_factor` times the batch
    duration, and it stops early once the total training time exceeds
    `timeout` (when given). `callbacks` is accepted but not used here.

    Returns a tuple `(completed, fit_duration, num_examples)`: whether all
    epochs ran without hitting the timeout, the elapsed time measured with
    `timeit.default_timer`, and the number of examples trained on.
    """
    batched = dataset.batch(batch_size=batch_size, drop_remainder=False)

    # Per-epoch metric history
    loss_history = []
    accuracy_history = []

    optimizer = tf.keras.optimizers.Adam()

    fit_begin = timeit.default_timer()
    num_examples = 0

    for epoch in range(num_epochs):
        log(INFO, "Starting epoch %s", epoch)

        loss_metric = tf.keras.metrics.Mean()
        accuracy_metric = tf.keras.metrics.CategoricalAccuracy()

        # One pass over the batched dataset
        batch_begin = timeit.default_timer()
        examples_this_epoch = 0
        for step, (features, labels) in enumerate(batched):
            examples_this_epoch += len(features)

            # Compute gradients and apply one optimizer step
            loss_value, grads = grad(model, features, labels)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            # Update the running loss and accuracy of this epoch
            loss_metric.update_state(loss_value)
            accuracy_metric.update_state(labels, model(features, training=True))

            # Count the examples used for training across all epochs
            num_examples += features.shape[0]

            # Optionally slow down in proportion to the batch duration
            batch_duration = timeit.default_timer() - batch_begin
            if delay_factor > 0.0:
                time.sleep(batch_duration * delay_factor)

            # Report progress on every 100th batch
            if not step % 100:
                log(
                    INFO,
                    "Batch %s: loss %s (%s examples processed, batch duration: %s)",
                    step,
                    loss_value,
                    examples_this_epoch,
                    batch_duration,
                )

            # Give up once the time budget is exceeded
            if timeout is not None:
                elapsed = timeit.default_timer() - fit_begin
                if elapsed > timeout:
                    log(INFO, "client timeout")
                    return False, elapsed, num_examples

            # Restart the clock for the next batch
            batch_begin = timeit.default_timer()

        # Epoch finished: record and log its metrics
        epoch_loss = loss_metric.result()
        epoch_accuracy = accuracy_metric.result()
        loss_history.append(epoch_loss)
        accuracy_history.append(epoch_accuracy)
        log(
            INFO,
            "Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(
                epoch, epoch_loss, epoch_accuracy
            ),
        )

    total_duration = timeit.default_timer() - fit_begin
    return True, total_duration, num_examples
def fit(self, num_rounds: int) -> History:
    """Train for up to `num_rounds` rounds and return the metric history.

    Each round fits on clients, evaluates through the strategy, evaluates on
    clients, and then asks `strategy.on_conclude_round` whether to continue.
    After the last round a shutdown signal is sent to all clients known to
    the client manager.
    """
    history = History()

    # Start from the initial weights and record their centralized metrics
    self.weights = self._get_initial_weights()
    initial = self.strategy.evaluate(weights=self.weights)
    if initial is not None:
        log(
            INFO,
            "initial weights (loss/accuracy): %s, %s",
            initial[0],
            initial[1],
        )
        history.add_loss_centralized(rnd=0, loss=initial[0])
        history.add_accuracy_centralized(rnd=0, acc=initial[1])

    log(INFO, "[TIME] FL starting")
    start_time = timeit.default_timer()

    for current_round in range(1, num_rounds + 1):
        # Keep the previous global weights if the round produced none
        new_weights = self.fit_round(rnd=current_round)
        if new_weights is not None:
            self.weights = new_weights

        # Centralized evaluation through the strategy
        loss_cen = None
        acc_cen = None
        res_cen = self.strategy.evaluate(weights=self.weights)
        if res_cen is not None:
            loss_cen, acc_cen = res_cen
            log(
                INFO,
                "fit progress: (%s, %s, %s, %s)",
                current_round,
                loss_cen,
                acc_cen,
                timeit.default_timer() - start_time,
            )
            history.add_loss_centralized(rnd=current_round, loss=loss_cen)
            history.add_accuracy_centralized(rnd=current_round, acc=acc_cen)

        # Federated evaluation on a sample of clients
        res_fed = self.evaluate(rnd=current_round)
        if res_fed is not None and res_fed[0] is not None:
            loss_fed, _ = res_fed
            history.add_loss_distributed(
                rnd=current_round, loss=cast(float, loss_fed)
            )

        # The strategy may end training early
        if not self.strategy.on_conclude_round(current_round, loss_cen, acc_cen):
            break

    # Tell every known client to shut down
    all_clients = self._client_manager.all()
    shutdown(clients=list(all_clients.values()))

    # Bookkeeping
    elapsed = timeit.default_timer() - start_time
    log(INFO, "[TIME] FL finished in %s", elapsed)
    return history
def start_server(
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
) -> None:
    """Start a Flower server using the gRPC transport layer.

    Trains for `config["num_rounds"]` rounds (1 when not given), logs the
    collected history, runs one final federated evaluation, and stops the
    gRPC server.

    Arguments:
        server_address: Address the insecure gRPC server listens on.
        server: Server instance to run. When `None`, a `Server` with a
            `SimpleClientManager` and `strategy` is created.
        config: Optional settings; only the key `"num_rounds"` is read. The
            dictionary passed by the caller is not modified.
        strategy: Strategy for a newly created server (`FedAvg` when
            `None`). Ignored when `server` is given.

    Returns:
        None.
    """
    # Create server instance if none was given
    if server is None:
        client_manager = SimpleClientManager()
        if strategy is None:
            strategy = FedAvg()
        server = Server(client_manager=client_manager, strategy=strategy)

    # Resolve the round count without writing into the caller's dictionary
    num_rounds = (config or {}).get("num_rounds", 1)

    # Start gRPC server
    grpc_server = start_insecure_grpc_server(
        client_manager=server.client_manager(), server_address=server_address
    )

    # The gRPC server is live from here on: stop it on every exit path
    try:
        log(INFO, "Flower server running (insecure, %s rounds)", num_rounds)

        # Fit model
        hist = server.fit(num_rounds=num_rounds)
        log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
        log(
            INFO,
            "app_fit: accuracies_distributed %s",
            str(hist.accuracies_distributed),
        )
        log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
        log(
            INFO,
            "app_fit: accuracies_centralized %s",
            str(hist.accuracies_centralized),
        )

        # Temporary workaround to force distributed evaluation
        server.strategy.eval_fn = None  # type: ignore

        # Evaluate the final trained model
        res = server.evaluate(rnd=-1)
        if res is not None:
            loss, (results, failures) = res
            log(INFO, "app_evaluate: federated loss: %s", str(loss))
            log(
                INFO,
                "app_evaluate: results %s",
                str([(result[0].cid, result[1]) for result in results]),
            )
            log(INFO, "app_evaluate: failures %s", str(failures))
        else:
            log(INFO, "app_evaluate: no evaluation result")
    finally:
        # Stop the gRPC server
        grpc_server.stop(1)
def fit(self, num_rounds: int) -> History:
    """Train for `num_rounds` rounds and return the collected history.

    Every round fits on clients, evaluates through the strategy, and then
    evaluates on clients. Centralized and distributed metrics are recorded
    in the returned `History` whenever the respective evaluation yields a
    result.
    """
    history = History()

    # Obtain the starting weights and score them once before training
    log(INFO, "Getting initial parameters")
    self.weights = self._get_initial_weights()
    log(INFO, "Evaluating initial parameters")
    initial = self.strategy.evaluate(weights=self.weights)
    if initial is not None:
        log(
            INFO,
            "initial weights (loss/accuracy): %s, %s",
            initial[0],
            initial[1],
        )
        history.add_loss_centralized(rnd=0, loss=initial[0])
        history.add_accuracy_centralized(rnd=0, acc=initial[1])

    log(INFO, "[TIME] FL starting")
    start_time = timeit.default_timer()

    current_round = 1
    while current_round <= num_rounds:
        # Replace the global weights only if the round produced new ones
        updated = self.fit_round(rnd=current_round)
        if updated is not None:
            self.weights = updated

        # Centralized evaluation through the strategy
        res_cen = self.strategy.evaluate(weights=self.weights)
        if res_cen is not None:
            loss_cen, acc_cen = res_cen
            log(
                INFO,
                "fit progress: (%s, %s, %s, %s)",
                current_round,
                loss_cen,
                acc_cen,
                timeit.default_timer() - start_time,
            )
            history.add_loss_centralized(rnd=current_round, loss=loss_cen)
            history.add_accuracy_centralized(rnd=current_round, acc=acc_cen)

        # Federated evaluation on a sample of available clients
        res_fed = self.evaluate(rnd=current_round)
        if res_fed is not None and res_fed[0] is not None:
            loss_fed, _ = res_fed
            history.add_loss_distributed(
                rnd=current_round, loss=cast(float, loss_fed)
            )

        current_round += 1

    # Bookkeeping
    elapsed = timeit.default_timer() - start_time
    log(INFO, "[TIME] FL finished in %s", elapsed)
    return history
def on_channel_state_change(channel_connectivity: str) -> None:
    """Write the reported channel connectivity state to the DEBUG log."""
    log(DEBUG, channel_connectivity)
def configure_fit(
    self, rnd: int, weights: Weights, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
    """Select clients and build their fit instructions for round `rnd`.

    Sampling depends on the strategy flags: contribution-based sampling when
    `alternating_timeout` is set; otherwise, with `importance_sampling`, 1/k
    sampling in round 1 and `_fs_based_sampling` afterwards; 1/k sampling in
    all other cases.

    The `"timeout"` config entry is taken from past durations when
    `dynamic_timeout` is set (falling back to `t_slow` while there are
    none), alternates between `t_fast` and `t_slow` when
    `alternating_timeout` is set, and is `t_slow` otherwise.

    Returns an empty list if not enough clients become available within
    `WAIT_TIMEOUT`; otherwise one `(client, FitIns)` pair per sampled client.
    """
    # Wait for the minimum number of clients before doing anything else
    sample_size, min_num_clients = self.num_fit_clients(
        client_manager.num_available()
    )
    enough_clients = client_manager.wait_for(
        num_clients=min_num_clients, timeout=WAIT_TIMEOUT
    )
    if not enough_clients:
        # Too few clients: skip this round
        log(
            INFO,
            "FedFS: not enough clients available after timeout %s",
            WAIT_TIMEOUT,
        )
        return []

    # Choose the sampling scheme for this round
    if self.alternating_timeout:
        log(
            DEBUG,
            "FedFS round %s, sample %s clients (based on all previous contributions)",
            str(rnd),
            str(sample_size),
        )
        selected = self._contribution_based_sampling(
            sample_size=sample_size, client_manager=client_manager
        )
    elif self.importance_sampling and rnd == 1:
        # No history exists in the first round, so sample with 1/k
        log(
            DEBUG,
            "FedFS round %s, sample %s clients with 1/k",
            str(rnd),
            str(sample_size),
        )
        selected = self._one_over_k_sampling(
            sample_size=sample_size, client_manager=client_manager
        )
    elif self.importance_sampling:
        fast_round = is_fast_round(
            rnd - 1, r_fast=self.r_fast, r_slow=self.r_slow
        )
        log(
            DEBUG,
            "FedFS round %s, sample %s clients, fast_round %s",
            str(rnd),
            str(sample_size),
            str(fast_round),
        )
        selected = self._fs_based_sampling(
            sample_size=sample_size,
            client_manager=client_manager,
            fast_round=fast_round,
        )
    else:
        selected = self._one_over_k_sampling(
            sample_size=sample_size, client_manager=client_manager
        )

    # Serialize the weights and assemble the round configuration
    parameters = weights_to_parameters(weights)
    config = {}
    if self.on_fit_config_fn is not None:
        # A user-supplied function replaces the empty default config
        config = self.on_fit_config_fn(rnd)

    # Decide the timeout of this round
    if self.dynamic_timeout and self.durations:
        candidates = timeout_candidates(
            durations=self.durations,
            max_timeout=self.t_slow,
        )
        round_timeout = next_timeout(
            candidates=candidates,
            percentile=self.dynamic_timeout_percentile,
        )
    elif self.dynamic_timeout:
        # The initial round has no past durations, use the maximum timeout
        round_timeout = self.t_slow
    elif self.alternating_timeout:
        if is_fast_round(rnd - 1, self.r_fast, self.r_slow):
            round_timeout = self.t_fast
        else:
            round_timeout = self.t_slow
    else:
        round_timeout = self.t_slow
    config["timeout"] = str(round_timeout)

    # One shared instruction object for all selected clients
    fit_ins = FitIns(parameters, config)
    return [(proxy, fit_ins) for proxy in selected]
def __init__(
    self,
    fraction_fit: float = 0.1,
    fraction_eval: float = 0.1,
    min_fit_clients: int = 2,
    min_eval_clients: int = 2,
    min_available_clients: int = 2,
    eval_fn: Optional[
        Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]]
    ] = None,
    on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
    on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
    accept_failures: bool = True,
    initial_parameters: Optional[Parameters] = None,
    dummy_model=None,
    quantize_bits=64,
) -> None:
    """Federated Averaging strategy.

    Implementation based on https://arxiv.org/abs/1602.05629

    A warning is logged when either `min_fit_clients` or `min_eval_clients`
    exceeds `min_available_clients`.

    Parameters
    ----------
    fraction_fit : float, optional
        Fraction of clients used during training. Defaults to 0.1.
    fraction_eval : float, optional
        Fraction of clients used during validation. Defaults to 0.1.
    min_fit_clients : int, optional
        Minimum number of clients used during training. Defaults to 2.
    min_eval_clients : int, optional
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int, optional
        Minimum number of total clients in the system. Defaults to 2.
    eval_fn : Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]]
        Optional function used for validation. Defaults to None.
    on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
        Function used to configure validation. Defaults to None.
    accept_failures : bool, optional
        Whether or not accept rounds containing failures. Defaults to True.
    initial_parameters : Parameters, optional
        Initial global model parameters.
    dummy_model : optional
        Model used to determine the dimensions of the weights vector if
        quantization is used. Defaults to None.
    quantize_bits : optional
        Stored as `self.q_bits`; presumably the bit width used for
        quantization (verify against the code that reads `q_bits`).
        Defaults to 64.
    """
    super().__init__()

    # Warn when more clients are required per round than must be available
    if max(min_fit_clients, min_eval_clients) > min_available_clients:
        log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

    # Client sampling
    self.fraction_fit = fraction_fit
    self.fraction_eval = fraction_eval
    self.min_fit_clients = min_fit_clients
    self.min_eval_clients = min_eval_clients
    self.min_available_clients = min_available_clients

    # Hooks and behaviour switches
    self.eval_fn = eval_fn
    self.on_fit_config_fn = on_fit_config_fn
    self.on_evaluate_config_fn = on_evaluate_config_fn
    self.accept_failures = accept_failures
    self.initial_parameters = initial_parameters

    # Quantization: the dummy model is used to determine the dimensions of
    # the weights vector if quantization is used
    self.dummy_model = dummy_model
    self.q_bits = quantize_bits
def start_simulation(  # pylint: disable=too-many-arguments
    *,
    client_fn: Callable[[str], Client],
    num_clients: Optional[int] = None,
    clients_ids: Optional[List[str]] = None,
    client_resources: Optional[Dict[str, int]] = None,
    num_rounds: int = 1,
    strategy: Optional[Strategy] = None,
    ray_init_args: Optional[Dict[str, Any]] = None,
) -> History:
    """Start a Ray-based Flower simulation server.

    Parameters
    ----------
    client_fn : Callable[[str], Client]
        A function creating client instances. The function must take a
        single str argument called `cid`. It should return a single client
        instance. Note that the created client instances are ephemeral and
        will often be destroyed after a single method invocation. Since
        client instances are not long-lived, they should not attempt to
        carry state over method invocations. Any state required by the
        instance (model, dataset, hyperparameters, ...) should be
        (re-)created in either the call to `client_fn` or the call to any
        of the client methods (e.g., load evaluation data in the `evaluate`
        method itself).
    num_clients : Optional[int]
        The total number of clients in this simulation. This must be set if
        `clients_ids` is not set and vice-versa.
    clients_ids : Optional[List[str]]
        List `client_id`s for each client. This is only required if
        `num_clients` is not set. Setting both `num_clients` and
        `clients_ids` with `len(clients_ids)` not equal to `num_clients`
        generates an error.
    client_resources : Optional[Dict[str, int]] (default: None)
        CPU and GPU resources for a single client. Supported keys are
        `num_cpus` and `num_gpus`. Example: `{"num_cpus": 4, "num_gpus": 1}`.
        To understand the GPU utilization caused by `num_gpus`, consult the
        Ray documentation on GPU support.
    num_rounds : int (default: 1)
        The number of rounds to train.
    strategy : Optional[flwr.server.Strategy] (default: None)
        An implementation of the abstract base class
        `flwr.server.Strategy`. If no strategy is provided, a default is
        created by `_init_defaults` (expected to be
        `flwr.server.strategy.FedAvg` - confirm there).
    ray_init_args : Optional[Dict[str, Any]] (default: None)
        Optional dictionary containing arguments for the call to
        `ray.init`. If ray_init_args is None (the default), Ray will be
        initialized with the following default args:

            {
                "ignore_reinit_error": True,
                "include_dashboard": False,
            }

        An empty dictionary can be used (ray_init_args={}) to prevent any
        arguments from being passed to ray.init.

    Returns
    -------
    hist : flwr.server.history.History
        Object containing metrics from training.

    Notes
    -----
    Invalid combinations of `num_clients` and `clients_ids` are logged at
    ERROR level and terminate the process through `sys.exit()`.
    """
    cids: List[str]

    # clients_ids takes precedence
    if clients_ids is not None:
        if (num_clients is not None) and (len(clients_ids) != num_clients):
            log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
            sys.exit()
        cids = clients_ids
    else:
        if num_clients is None:
            log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
            sys.exit()
        cids = [str(x) for x in range(num_clients)]

    # Apply the defaults only when nothing was passed. An explicit empty
    # dictionary must reach `ray.init` unchanged, as documented above; a
    # truthiness check would silently replace it with the defaults.
    if ray_init_args is None:
        ray_init_args = {
            "ignore_reinit_error": True,
            "include_dashboard": False,
        }

    # Shut down Ray if it has already been initialized
    if ray.is_initialized():
        ray.shutdown()

    # Initialize Ray
    ray.init(**ray_init_args)
    log(
        INFO,
        "Ray initialized with resources: %s",
        ray.cluster_resources(),
    )

    # Initialize server and server config
    config = {"num_rounds": num_rounds}
    initialized_server, initialized_config = _init_defaults(
        None, config, strategy
    )
    log(
        INFO,
        "Starting Flower simulation running: %s",
        initialized_config,
    )

    # Register one RayClientProxy object for each client with the ClientManager
    resources = client_resources if client_resources is not None else {}
    for cid in cids:
        client_proxy = RayClientProxy(
            client_fn=client_fn,
            cid=cid,
            resources=resources,
        )
        initialized_server.client_manager().register(client=client_proxy)

    # Start training
    hist = _fl(
        server=initialized_server,
        config=initialized_config,
        force_final_distributed_eval=False,
    )

    return hist
def main() -> None:
    """Log the download announcement and call `hotkey_load`."""
    log(INFO, "Download Keyword Detection")
    tf_hotkey_partitioned.hotkey_load()
def start_client(
    server_address: str,
    client: Client,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    root_certificates: Optional[bytes] = None,
) -> None:
    """Start a Flower Client which connects to a gRPC server.

    The client answers server messages until `handle` signals the end of the
    session. If `handle` then reports a non-zero sleep duration, the client
    sleeps that long and reconnects; otherwise it shuts down.

    Parameters
    ----------
    server_address: str.
        The IPv6 address of the server. If the Flower server runs on the
        same machine on port 8080, then `server_address` would be
        `"[::]:8080"`.
    client: flwr.client.Client.
        An implementation of the abstract base class `flwr.client.Client`.
    grpc_max_message_length: int (default: 536_870_912, this equals 512MB).
        The maximum length of gRPC messages that can be exchanged with the
        Flower server. The default should be sufficient for most models.
        Users who train very large models might need to increase this
        value. Note that the Flower server needs to be started with the
        same value (see `flwr.server.start_server`), otherwise it will not
        know about the increased limit and block larger messages.
    root_certificates: bytes (default: None)
        The PEM-encoded root certificates as a byte string. If provided, a
        secure connection using the certificates will be established to a
        SSL-enabled Flower server.

    Returns
    -------
    None

    Examples
    --------
    Starting a client with insecure server connection:

    >>> start_client(
    ...     server_address="localhost:8080",
    ...     client=FlowerClient(),
    ... )

    Starting a SSL-enabled client:

    >>> from pathlib import Path
    >>> start_client(
    ...     server_address="localhost:8080",
    ...     client=FlowerClient(),
    ...     root_certificates=Path("/crts/root.pem").read_bytes(),
    ... )
    """
    # Report the actual transport security. `grpc_connection` is given the
    # same `root_certificates`, so the log must not always say "insecure".
    security = "secure" if root_certificates is not None else "insecure"

    while True:
        sleep_duration: int = 0
        with grpc_connection(
            server_address,
            max_message_length=grpc_max_message_length,
            root_certificates=root_certificates,
        ) as conn:
            receive, send = conn
            log(INFO, "Opened (%s) gRPC connection", security)

            # Answer server messages until `handle` asks us to stop
            while True:
                server_message = receive()
                client_message, sleep_duration, keep_going = handle(
                    client, server_message
                )
                send(client_message)
                if not keep_going:
                    break

        # The connection is closed here; a zero duration means final shutdown
        if sleep_duration == 0:
            log(INFO, "Disconnect and shut down")
            break

        # Sleep and reconnect afterwards
        log(
            INFO,
            "Disconnect, then re-establish connection after %s second(s)",
            sleep_duration,
        )
        time.sleep(sleep_duration)
def grpc_connection(
    server_address: str,
    max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    root_certificates: Optional[bytes] = None,
) -> Iterator[Tuple[Callable[[], ServerMessage], Callable[[ClientMessage], None]]]:
    """Establish a (secure or insecure) gRPC connection to a gRPC server.

    This is a generator that yields exactly once. The example below uses it
    in a `with` statement, which presumes a context-manager decorator is
    applied outside of this definition (NOTE(review): confirm).

    Parameters
    ----------
    server_address : str
        The IPv6 address of the server. If the Flower server runs on the
        same machine on port 8080, then `server_address` would be
        `"[::]:8080"`.
    max_message_length : int
        The maximum length of gRPC messages that can be exchanged with the
        Flower server. The default should be sufficient for most models.
        Users who train very large models might need to increase this
        value. Note that the Flower server needs to be started with the
        same value (see `flwr.server.start_server`), otherwise it will not
        know about the increased limit and block larger messages.
        (default: 536_870_912, this equals 512MB)
    root_certificates : Optional[bytes] (default: None)
        The PEM-encoded root certificates as a byte string. If provided, a
        secure connection using the certificates will be established to a
        SSL-enabled Flower server.

    Returns
    -------
    receive, send : Callable, Callable
        `receive()` returns the next message from the server stream;
        `send(msg)` enqueues a client message without blocking.

    Examples
    --------
    Establishing a SSL-enabled connection to the server:

    >>> from pathlib import Path
    >>> with grpc_connection(
    >>>     server_address,
    >>>     max_message_length=grpc_max_message_length,
    >>>     root_certificates=Path("/crts/root.pem").read_bytes(),
    >>> ) as conn:
    >>>     receive, send = conn
    >>>     server_message = receive()
    >>>     # do something here
    >>>     send(client_message)
    """
    # The same limit applies to both directions
    channel_options = [
        ("grpc.max_send_message_length", max_message_length),
        ("grpc.max_receive_message_length", max_message_length),
    ]

    if root_certificates is not None:
        ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates)
        channel = grpc.secure_channel(
            server_address, ssl_channel_credentials, options=channel_options
        )
        # Remembered so the close log reports the channel type actually used
        channel_kind = "Secure"
    else:
        channel = grpc.insecure_channel(server_address, options=channel_options)
        channel_kind = "Insecure"

    channel.subscribe(on_channel_state_change)

    # Single-slot queue feeding the request stream; `send` puts without
    # blocking, so a second message raises if the slot is still occupied.
    queue: Queue[ClientMessage] = Queue(  # pylint: disable=unsubscriptable-object
        maxsize=1
    )
    stub = FlowerServiceStub(channel)

    # `iter(queue.get, None)` yields queued messages until a `None` sentinel
    server_message_iterator: Iterator[ServerMessage] = stub.Join(
        iter(queue.get, None)
    )

    receive: Callable[[], ServerMessage] = lambda: next(server_message_iterator)
    send: Callable[[ClientMessage], None] = lambda msg: queue.put(msg, block=False)

    try:
        yield (receive, send)
    finally:
        # Always release the channel when the generator is finalised
        channel.close()
        log(DEBUG, "%s gRPC channel closed", channel_kind)
def main() -> None:
    """Log the download and load the Fashion-MNIST dataset through Keras.

    The arrays returned by `load_data()` are not used here.
    """
    log(INFO, "Download Fashion-MNIST")
    fashion_mnist = tf.keras.datasets.fashion_mnist
    fashion_mnist.load_data()
def _fl(
    server: Server,
    config: Dict[str, int],
    force_final_distributed_eval: bool = True,
) -> None:
    """Train, optionally run a final distributed evaluation, and disconnect.

    Parameters
    ----------
    server : Server
        The server instance that runs training and evaluation.
    config : Dict[str, int]
        Must contain `"num_rounds"`, the number of rounds passed to
        `server.fit`.
    force_final_distributed_eval : bool (default: True)
        When True, clear the strategy's `eval_fn` and run one more
        `server.evaluate_round(rnd=-1)` after training. The default keeps
        the behaviour of two-argument calls unchanged; callers that pass
        the keyword explicitly can turn the final evaluation off.

    Returns
    -------
    None
    """
    # Fit model
    hist = server.fit(num_rounds=config["num_rounds"])
    log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
    log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
    log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
    log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))

    if force_final_distributed_eval:
        # Temporary workaround to force distributed evaluation
        server.strategy.eval_fn = None  # type: ignore

        # Evaluate the final trained model
        res = server.evaluate_round(rnd=-1)
        if res is not None:
            loss, _, (results, failures) = res
            log(INFO, "app_evaluate: federated loss: %s", str(loss))
            log(
                INFO,
                "app_evaluate: results %s",
                # Distinct loop name so the outer `res` is not shadowed
                str([(result[0].cid, result[1]) for result in results]),
            )
            log(INFO, "app_evaluate: failures %s", str(failures))
        else:
            log(INFO, "app_evaluate: no evaluation result")

    # Graceful shutdown
    server.disconnect_all_clients()
def main() -> None:
    """Start server and train a number of rounds.

    Reads the server setting selected on the command line, builds the
    centralized evaluation function and the requested strategy, and starts
    the Flower server.

    Raises
    ------
    ValueError
        If the configured strategy name is not one of `fedavg`,
        `fast-and-slow`, `fedfs-v0` or `qffedavg`, or if a timeout-based
        strategy is selected without `training_round_timeout`.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_hotkey_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1
    )
    if server_setting.dry_run:
        # Keep only the first 50 test examples for a dry run
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = keyword_cnn(input_shape=(80, 40, 1), seed=SEED)

    # Strategy
    eval_fn = get_eval_fn(model=model, num_classes=10, xy_test=(x_test, y_test))
    on_fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    # Exactly one branch may match; an unknown name fails loudly instead of
    # leaving `strategy` unbound until `start_server` is called.
    if server_setting.strategy == "fedavg":
        strategy = fl.server.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    elif server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        strategy = fl.server.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.9,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            # Fast rounds get half of the configured timeout (rounded up)
            t_fast=math.ceil(0.5 * server_setting.training_round_timeout),
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "fedfs-v0":
        if server_setting.training_round_timeout is None:
            raise ValueError("No `training_round_timeout` set for `fedfs-v0` strategy")
        # Prefer the explicit short timeout; otherwise use half of the
        # regular timeout (rounded up)
        t_fast = (
            math.ceil(0.5 * server_setting.training_round_timeout)
            if server_setting.training_round_timeout_short is None
            else server_setting.training_round_timeout_short
        )
        strategy = fl.server.strategy.FedFSv0(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "qffedavg":
        strategy = fl.server.strategy.QFedAvg(
            q_param=0.2,
            qffl_learning_rate=0.1,
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    else:
        raise ValueError(f"Unknown strategy: {server_setting.strategy!r}")

    # Run server
    fl.server.start_server(
        DEFAULT_SERVER_ADDRESS,
        config={"num_rounds": server_setting.rounds},
        strategy=strategy,
    )
def fit(self, num_rounds: int) -> History:
    """Run federated averaging for a number of rounds.

    Parameters
    ----------
    num_rounds : int
        Number of rounds to run; rounds are numbered 1..num_rounds.

    Returns
    -------
    History
        Centralized losses/metrics (round 0 holds the evaluation of the
        initial parameters) and distributed losses/metrics, recorded for
        each round in which the respective evaluation produced a result.
    """
    history = History()

    # Initialize parameters
    log(INFO, "Initializing global parameters")
    self.parameters = self._get_initial_parameters()
    log(INFO, "Evaluating initial parameters")
    res = self.strategy.evaluate(parameters=self.parameters)
    if res is not None:
        log(
            INFO,
            "initial parameters (loss, other metrics): %s, %s",
            res[0],
            res[1],
        )
        history.add_loss_centralized(rnd=0, loss=res[0])
        history.add_metrics_centralized(rnd=0, metrics=res[1])

    # Run federated learning for num_rounds
    log(INFO, "FL starting")
    start_time = timeit.default_timer()

    for current_round in range(1, num_rounds + 1):
        # Train model and replace previous global model
        res_fit = self.fit_round(rnd=current_round)
        if res_fit:
            parameters_prime, _, _ = res_fit  # fit_metrics_aggregated
            if parameters_prime:
                self.parameters = parameters_prime

        # Evaluate model using strategy implementation
        res_cen = self.strategy.evaluate(parameters=self.parameters)
        if res_cen is not None:
            loss_cen, metrics_cen = res_cen
            log(
                INFO,
                "fit progress: (%s, %s, %s, %s)",
                current_round,
                loss_cen,
                metrics_cen,
                timeit.default_timer() - start_time,
            )
            history.add_loss_centralized(rnd=current_round, loss=loss_cen)
            history.add_metrics_centralized(rnd=current_round, metrics=metrics_cen)

        # Evaluate model on a sample of available clients
        res_fed = self.evaluate_round(rnd=current_round)
        if res_fed:
            loss_fed, evaluate_metrics_fed, _ = res_fed
            # Compare against None: a loss of exactly 0.0 is a valid result
            # and must still be recorded.
            if loss_fed is not None:
                history.add_loss_distributed(rnd=current_round, loss=loss_fed)
                history.add_metrics_distributed(
                    rnd=current_round, metrics=evaluate_metrics_fed
                )

    # Bookkeeping
    end_time = timeit.default_timer()
    elapsed = end_time - start_time
    log(INFO, "FL finished in %s", elapsed)
    return history
def start_server(  # pylint: disable=too-many-arguments
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    force_final_distributed_eval: bool = False,
) -> None:
    """Start a Flower server using the gRPC transport layer.

    Arguments:
        server_address: Optional[str] (default: `"[::]:8080"`). The IPv6
            address of the server.
        server: Optional[flwr.server.Server] (default: None). An implementation
            of the abstract base class `flwr.server.Server`. If no instance is
            provided, then `start_server` will create one.
        config: Optional[Dict[str, int]] (default: None). The only currently
            supported value is `num_rounds`, so a full configuration object
            instructing the server to perform three rounds of federated
            learning looks like the following: `{"num_rounds": 3}`.
        strategy: Optional[flwr.server.Strategy] (default: None). An
            implementation of the abstract base class `flwr.server.Strategy`.
            If no strategy is provided, then `start_server` will use
            `flwr.server.strategy.FedAvg`.
        grpc_max_message_length: int (default: 536_870_912, this equals 512MB).
            The maximum length of gRPC messages that can be exchanged with the
            Flower clients. The default should be sufficient for most models.
            Users who train very large models might need to increase this
            value. Note that the Flower clients need to be started with the
            same value (see `flwr.client.start_client`), otherwise clients will
            not know about the increased limit and block larger messages.
        force_final_distributed_eval: bool (default: False).
            Forces a distributed evaluation to occur after the last training
            epoch when enabled.

    Returns:
        None.
    """
    initialized_server, initialized_config = _init_defaults(server, config, strategy)

    # Start gRPC server
    grpc_server = start_insecure_grpc_server(
        client_manager=initialized_server.client_manager(),
        server_address=server_address,
        max_message_length=grpc_max_message_length,
    )
    try:
        log(
            INFO,
            "Flower server running (insecure, %s rounds)",
            initialized_config["num_rounds"],
        )

        _fl(
            server=initialized_server,
            config=initialized_config,
            force_final_distributed_eval=force_final_distributed_eval,
        )
    finally:
        # Stop the gRPC server even if training or evaluation raised, so
        # that the server is not left running after a failure
        grpc_server.stop(grace=1)
partial_updates: bool) -> Callable[[int], Dict[str, fl.common.Scalar]]: """Return a function which returns training configurations.""" def fit_config(rnd: int) -> Dict[str, fl.common.Scalar]: """Return a configuration with static batch size and (local) epochs.""" config: Dict[str, fl.common.Scalar] = { "epoch_global": str(rnd), "epochs": str(5), "batch_size": str(32), "lr_initial": str(lr_initial), "lr_decay": str(0.99), "partial_updates": "1" if partial_updates else "0", } if timeout is not None: config["timeout"] = str(timeout) return config return fit_config if __name__ == "__main__": # pylint: disable=broad-except try: main() except Exception as err: log(ERROR, "Fatal error in main") log(ERROR, err, exc_info=True, stack_info=True) # Raise the error again so the exit code is correct raise err
def start_server(
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
) -> None:
    """Start a Flower server using the gRPC transport layer.

    Arguments:
        server_address: Optional[str] (default: `"[::]:8080"`). The IPv6
            address of the server.
        server: Optional[flwr.server.Server] (default: None). An implementation
            of the abstract base class `flwr.server.Server`. If no instance is
            provided, then `start_server` will create one.
        config: Optional[Dict[str, int]] (default: None). The only currently
            supported value is `num_rounds`, so a full configuration object
            instructing the server to perform three rounds of federated
            learning looks like the following: `{"num_rounds": 3}`. A missing
            `num_rounds` is set to 1.
        strategy: Optional[flwr.server.Strategy] (default: None). An
            implementation of the abstract base class `flwr.server.Strategy`.
            If no strategy is provided, then `start_server` will use
            `flwr.server.strategy.FedAvg`. Only consulted when `server` is
            None.
        grpc_max_message_length: int (default: 536_870_912, this equals 512MB).
            The maximum length of gRPC messages that can be exchanged with the
            Flower clients. The default should be sufficient for most models.
            Users who train very large models might need to increase this
            value. Note that the Flower clients need to be started with the
            same value (see `flwr.client.start_client`), otherwise clients will
            not know about the increased limit and block larger messages.

    Returns:
        None.
    """
    # Create server instance if none was given
    if server is None:
        client_manager = SimpleClientManager()
        if strategy is None:
            strategy = FedAvg()
        server = Server(client_manager=client_manager, strategy=strategy)

    # Set default config values
    if config is None:
        config = {}
    if "num_rounds" not in config:
        config["num_rounds"] = 1

    # Start gRPC server
    grpc_server = start_insecure_grpc_server(
        client_manager=server.client_manager(),
        server_address=server_address,
        max_message_length=grpc_max_message_length,
    )
    try:
        log(INFO, "Flower server running (insecure, %s rounds)", config["num_rounds"])

        # Fit model
        hist = server.fit(num_rounds=config["num_rounds"])
        log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
        log(INFO, "app_fit: accuracies_distributed %s", str(hist.accuracies_distributed))
        log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
        log(INFO, "app_fit: accuracies_centralized %s", str(hist.accuracies_centralized))

        # Temporary workaround to force distributed evaluation
        server.strategy.eval_fn = None  # type: ignore

        # Evaluate the final trained model
        res = server.evaluate(rnd=-1)
        if res is not None:
            loss, (results, failures) = res
            log(INFO, "app_evaluate: federated loss: %s", str(loss))
            log(
                INFO,
                "app_evaluate: results %s",
                # Distinct loop name so the outer `res` is not shadowed
                str([(result[0].cid, result[1]) for result in results]),
            )
            log(INFO, "app_evaluate: failures %s", str(failures))
        else:
            log(INFO, "app_evaluate: no evaluation result")

        # Graceful shutdown
        server.disconnect_all_clients()
    finally:
        # Stop the gRPC server even if training or evaluation raised, so
        # that the server is not left running after a failure
        grpc_server.stop(1)
def create_instances(
    self,
    num_cpu: int,
    num_ram: float,
    timeout: int,
    num_instance: int = 1,
    gpu: bool = False,
) -> List[AdapterInstance]:
    """Launch `num_instance` EC2 instances of one type and wait for them.

    Args:
        num_cpu (int): Number of instance vCPU (values in
            ec2_adapter.INSTANCE_TYPES_CPU or INSTANCE_TYPES_GPU)
        num_ram (float): RAM in GB (values in
            ec2_adapter.INSTANCE_TYPES_CPU or INSTANCE_TYPES_GPU)
        timeout (int): Timeout in minutes, passed to `shutdown -P` in the
            instance user data
        num_instance (int): Number of instances to start; used as both the
            minimum and maximum count of the request
        gpu (bool): Select the instance type from INSTANCE_TYPES_GPU instead
            of INSTANCE_TYPES_CPU

    Returns:
        The result of `self.list_instances` for the started instance IDs.

    Raises:
        EC2CreateInstanceFailure: If waiting for the instances to become
            reachable ends with EC2StatusTimeout; the instances are
            terminated first.
    """
    # Fail safe: the instance powers itself off after `timeout`, and the
    # shutdown behaviour requested below turns that into a termination in
    # case the instances are not shut down correctly otherwise.
    shutdown_script = "\n".join(("#!/bin/bash", f"sudo shutdown -P {timeout}"))

    if gpu:
        instance_table = INSTANCE_TYPES_GPU
    else:
        instance_table = INSTANCE_TYPES_CPU
    instance_type, price_per_hour = find_instance_type(
        num_cpu, num_ram, instance_table
    )
    total_price_per_hour = round(num_instance * price_per_hour, 2)

    log(
        INFO,
        "Starting %s instances of type %s which in total will roughly cost $%s an hour.",
        num_instance,
        instance_type,
        total_price_per_hour,
    )

    root_volume = {
        "DeviceName": "/dev/sda1",
        "Ebs": {"DeleteOnTermination": True},
    }
    launch_request = {
        "BlockDeviceMappings": [root_volume],
        "ImageId": self.image_id,
        # Min == Max: we always want an exact number of instances
        "MinCount": num_instance,
        "MaxCount": num_instance,
        "InstanceType": instance_type,
        "KeyName": self.key_name,
        "IamInstanceProfile": {"Name": "FlowerInstanceProfile"},
        "SubnetId": self.subnet_id,
        "SecurityGroupIds": self.security_group_ids,
        "TagSpecifications": self.tag_specifications,
        "InstanceInitiatedShutdownBehavior": "terminate",
        "UserData": shutdown_script,
    }
    response: EC2RunInstancesResult = self.ec2.run_instances(**launch_request)

    started_ids = []
    for description in response["Instances"]:
        started_ids.append(description["InstanceId"])

    # Once all instances report "running" we still have to check the
    # InstanceStatus, which reports impaired functionality that stems from
    # issues internal to the instance, such as impaired reachability.
    try:
        self._wait_until_instances_are_reachable(instance_ids=started_ids)
    except EC2StatusTimeout:
        self.terminate_instances(started_ids)
        raise EC2CreateInstanceFailure()

    return self.list_instances(instance_ids=started_ids)
def main() -> None:
    """Start server and train a number of rounds.

    Reads the server setting selected on the command line, builds the
    centralized evaluation function and the requested strategy, and starts
    the Flower server with an explicitly constructed `Server` instance.

    Raises
    ------
    ValueError
        If the configured strategy name is neither `fedavg` nor
        `fast-and-slow`, or if `fast-and-slow` is selected without
        `training_round_timeout`.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_cifar_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1, cifar100=NUM_CLASSES == 100
    )
    if server_setting.dry_run:
        # Keep only the first 50 test examples for a dry run
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = resnet50v2(input_shape=(32, 32, 3), num_classes=NUM_CLASSES, seed=SEED)

    # Create client_manager
    client_manager = fl.server.SimpleClientManager()

    # Strategy
    eval_fn = get_eval_fn(
        model=model, num_classes=NUM_CLASSES, xy_test=(x_test, y_test)
    )
    fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    # Exactly one branch may match; an unknown name fails loudly instead of
    # leaving `strategy` unbound until the server is constructed.
    if server_setting.strategy == "fedavg":
        strategy = fl.server.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=fit_config_fn,
        )
    elif server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        strategy = fl.server.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.8,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            # Fast rounds get half of the configured timeout (rounded up)
            t_fast=math.ceil(0.5 * server_setting.training_round_timeout),
            t_slow=server_setting.training_round_timeout,
        )
    else:
        raise ValueError(f"Unknown strategy: {server_setting.strategy!r}")

    # Run server
    server = fl.server.Server(client_manager=client_manager, strategy=strategy)
    fl.server.start_server(
        DEFAULT_SERVER_ADDRESS,
        server,
        config={"num_rounds": server_setting.rounds},
    )