def insecure_grpc_connection(
    server_address: str,
    max_message_length: int = 256 * 1024 * 1024,
) -> Iterator[Tuple[Callable[[], ServerMessage], Callable[[ClientMessage], None]]]:
    """Establish an insecure gRPC connection to a gRPC server.

    Args:
        server_address: Address of the gRPC server, passed unchanged to
            `grpc.insecure_channel`.
        max_message_length: Limit (in bytes) applied to both the send and
            the receive message length of the channel. Defaults to
            256 * 1024 * 1024, the value that used to be hard-coded.

    Yields:
        A `(receive, send)` pair. `receive()` returns the next
        `ServerMessage` from the response stream of `Join`; `send(msg)`
        puts a `ClientMessage` on the request queue without blocking, so
        it raises `queue.Full` if the single slot is still occupied.
    """
    channel = grpc.insecure_channel(
        server_address,
        options=[
            ("grpc.max_send_message_length", max_message_length),
            ("grpc.max_receive_message_length", max_message_length),
        ],
    )
    channel.subscribe(on_channel_state_change)

    # Single-slot queue feeding the request stream
    queue: Queue[ClientMessage] = Queue(  # pylint: disable=unsubscriptable-object
        maxsize=1)
    stub = FlowerServiceStub(channel)  # type: ignore

    # The request iterator blocks on `queue.get` and ends once `None` is
    # taken from the queue
    server_message_iterator: Iterator[ServerMessage] = stub.Join(
        iter(queue.get, None))

    receive: Callable[[], ServerMessage] = lambda: next(server_message_iterator)
    send: Callable[[ClientMessage], None] = lambda msg: queue.put(msg, block=False)

    try:
        yield (receive, send)
    finally:
        # Make sure the channel is always closed, even on errors
        channel.close()
        log(DEBUG, "Insecure gRPC channel closed")
def start_client(server_address: str, client: Client) -> None:
    """Connect a Flower client to a gRPC server and answer its messages.

    Opens an insecure gRPC connection to `server_address` and then loops
    without a termination condition: receive one server message, let
    `handle` compute the reply for `client`, and send that reply back.
    """
    with insecure_grpc_connection(server_address) as (receive, send):
        log(INFO, "Opened (insecure) gRPC connection")

        while True:
            # One request/response exchange per iteration
            reply = handle(client, receive())
            send(reply)
def exec(self, instance_name: str, command: str) -> ExecInfo:
    """Execute `command` on the named instance over SSH.

    Returns a `(stdout, stderr)` pair, each being the list of lines read
    from the respective stream of the executed command.
    """
    log(DEBUG, "Exec on %s: %s", instance_name, command)

    target = self.get_instance(instance_name)

    with ssh_connection(target, self.ssh_credentials) as ssh:
        # The first element of the returned triple is not needed here
        _, out_stream, err_stream = ssh.exec_command(command)
        out_lines = out_stream.readlines()
        err_lines = err_stream.readlines()

    return out_lines, err_lines
def _fs_based_sampling(
    self, sample_size: int, client_manager: ClientManager, fast_round: bool
) -> List[ClientProxy]:
    """Sample clients weighted by their past contributions.

    For clients with recorded contributions the raw importance is
    `1/k * c/m + E` in fast rounds and `1 - c/m + E` in slow rounds, where
    `k` is the number of clients known to the client manager. Clients
    without recorded contributions get `1/k` (fast) or `1` (slow). The raw
    importances are handed to `normalize_and_sample` (without softmax).
    """
    all_clients: Dict[str, ClientProxy] = client_manager.all()
    k = len(all_clients)

    cid_idx: Dict[int, str] = {}
    raw: List[float] = []

    for idx, cid in enumerate(all_clients):
        cid_idx[idx] = cid

        if cid not in self.contributions:
            # Previously unselected clients
            raw.append(1 / k if fast_round else 1)
            continue

        # Previously selected clients
        contribs: List[Tuple[int, int, int]] = self.contributions[cid]

        # pylint: disable-msg=invalid-name
        if self.use_past_contributions:
            # Aggregate c and m over every recorded contribution
            total_c = sum(c for _, c, _ in contribs)
            total_m = sum(m for _, _, m in contribs)
            c_over_m = total_c / total_m
        else:
            # Only the most recent contribution counts
            _, c, m = contribs[-1]
            c_over_m = c / m
        # pylint: enable-msg=invalid-name

        if fast_round:
            importance = (1 / k) * c_over_m + E
        else:
            importance = 1 - c_over_m + E
        raw.append(importance)

    log(
        DEBUG,
        "FedFS _fs_based_sampling, sample %s clients, raw %s",
        str(sample_size),
        str(raw),
    )

    return normalize_and_sample(
        all_clients=all_clients,
        cid_idx=cid_idx,
        raw=np.array(raw),
        sample_size=sample_size,
        use_softmax=False,
    )
def main() -> None:
    """Read the CIFAR version from the command line and download the data."""
    arg_parser = argparse.ArgumentParser(description="Flower")
    arg_parser.add_argument(
        "--cifar",
        type=int,
        choices=[10, 100],
        default=10,
        help="CIFAR version, allowed values: 10 or 100 (default: 10)",
    )
    cifar_version = arg_parser.parse_args().cifar

    log(INFO, "Download CIFAR-%s", cifar_version)

    # The CIFAR version doubles as the number of classes
    download_data(num_classes=cifar_version)
def fit(self, ins: fl.FitIns) -> fl.FitRes:
    """Train the local model on the local dataset.

    `ins[0]` holds the parameters to start from and `ins[1]` the config.
    The config keys read here are `epochs`, `batch_size`, `partial_updates`
    and, optionally, `timeout`.

    Returns a tuple of (parameters, number of examples used for training,
    maximum number of examples that could have been processed, duration
    of the fit). The parameters are empty when training did not complete
    and partial updates are disabled.
    """
    weights: fl.Weights = fl.parameters_to_weights(ins[0])
    config = ins[1]
    log(
        DEBUG,
        "fit on %s (examples: %s), config %s",
        self.cid,
        self.num_examples_train,
        config,
    )

    # Training configuration (all values arrive as strings)
    epochs = int(config["epochs"])
    batch_size = int(config["batch_size"])
    timeout = int(config["timeout"]) if "timeout" in config else None
    partial_updates = bool(int(config["partial_updates"]))

    # Start from the provided weights
    self.model.set_weights(weights)

    # Run local training
    fit_result = custom_fit(
        model=self.model,
        dataset=self.ds_train,
        num_epochs=epochs,
        batch_size=batch_size,
        callbacks=[],
        delay_factor=self.delay_factor,
        timeout=timeout,
    )
    completed, fit_duration, num_examples = fit_result
    log(DEBUG, "client %s had fit_duration %s", self.cid, fit_duration)

    # Upper bound on the number of examples which could have been processed
    num_examples_ceil = self.num_examples_train * epochs

    if completed or partial_updates:
        # Return the refined weights
        parameters = fl.weights_to_parameters(self.model.get_weights())
    else:
        # Local update was cut short and partial updates are not wanted
        parameters = fl.weights_to_parameters([])

    return parameters, num_examples, num_examples_ceil, fit_duration
def main() -> None:
    """Set up and run a CIFAR-10/100 client.

    Parses the command line, looks up the client setting, configures
    logging, builds the model, loads this client's data partition and
    finally starts the Flower client against the given server address.
    """
    args = parse_args()
    client_setting = get_client_setting(args.setting, args.cid)

    # Configure logger
    configure(identifier=f"client:{client_setting.cid}", host=args.log_host)
    log(INFO, "Starting client, settings: %s", client_setting)

    # Load model
    model = resnet50v2(input_shape=(32, 32, 3), num_classes=NUM_CLASSES, seed=SEED)

    # Load all partitions, then pick the one assigned to this client
    partitions, _ = tf_cifar_partitioned.load_data(
        iid_fraction=client_setting.iid_fraction,
        num_partitions=client_setting.num_clients,
        cifar100=False,
    )
    train_partitions, test_partitions = partitions
    x_train, y_train = train_partitions[client_setting.partition]
    x_test, y_test = test_partitions[client_setting.partition]

    if client_setting.dry_run:
        # Dry runs only use a small prefix of the data
        x_train, y_train = x_train[0:100], y_train[0:100]
        x_test, y_test = x_test[0:50], y_test[0:50]

    # Start client
    client = VisionClassificationClient(
        client_setting.cid,
        model,
        (x_train, y_train),
        (x_test, y_test),
        client_setting.delay_factor,
        NUM_CLASSES,
        augment=True,
        augment_horizontal_flip=True,
        augment_offset=2,
    )
    fl.app.client.start_client(args.server_address, client)
def main() -> None:
    """Set up and run a Fashion-MNIST client.

    Parses the command line, looks up the client setting, configures
    logging, builds the model, loads this client's data partition and
    finally starts the Flower client against the given server address.
    """
    args = parse_args()
    client_setting = get_client_setting(args.setting, args.cid)

    # Configure logger
    configure(identifier=f"client:{client_setting.cid}", host=args.log_host)
    log(INFO, "Starting client, settings: %s", client_setting)

    # Load model
    model = orig_cnn(input_shape=(28, 28, 1), seed=SEED)

    # Load all partitions, then pick the one assigned to this client
    partitions, _ = tf_fashion_mnist_partitioned.load_data(
        iid_fraction=client_setting.iid_fraction,
        num_partitions=client_setting.num_clients,
    )
    train_partitions, test_partitions = partitions
    x_train, y_train = train_partitions[client_setting.partition]
    x_test, y_test = test_partitions[client_setting.partition]

    if client_setting.dry_run:
        # Dry runs only use a small prefix of the data
        x_train, y_train = x_train[0:100], y_train[0:100]
        x_test, y_test = x_test[0:50], y_test[0:50]

    # Start client
    client = VisionClassificationClient(
        client_setting.cid,
        model,
        (x_train, y_train),
        (x_test, y_test),
        client_setting.delay_factor,
        10,
        augment=True,
        augment_horizontal_flip=False,
        augment_offset=1,
    )
    fl.app.client.start_client(args.server_address, client)
def on_configure_fit(
    self, rnd: int, weights: Weights, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
    """Configure the next round of training.

    Waits (up to `WAIT_TIMEOUT`) for enough clients, samples clients based
    on their contributions and pairs every sampled client with the same
    fit instructions. The round timeout placed in the config is `t_fast`
    or `t_slow`, depending on `is_fast_round(rnd - 1, ...)`.

    Returns an empty list if not enough clients became available.
    """
    # Block until `min_num_clients` are available
    sample_size, min_num_clients = self.num_fit_clients(
        client_manager.num_available()
    )
    enough_clients = client_manager.wait_for(
        num_clients=min_num_clients, timeout=WAIT_TIMEOUT
    )
    if not enough_clients:
        # Do not continue if not enough clients are available
        log(
            INFO,
            "FedFS: not enough clients available after timeout %s",
            WAIT_TIMEOUT,
        )
        return []

    # Sample clients
    sampled = self._contribution_based_sampling(
        sample_size=sample_size, client_manager=client_manager
    )

    # Prepare parameters and config
    parameters = weights_to_parameters(weights)
    if self.on_fit_config_fn is None:
        config = {}
    else:
        # Use custom fit config function if provided
        config = self.on_fit_config_fn(rnd)

    # Set timeout for this round
    if is_fast_round(rnd - 1, self.r_fast, self.r_slow):
        round_timeout = self.t_fast
    else:
        round_timeout = self.t_slow
    config["timeout"] = str(round_timeout)

    # Every sampled client receives the same instructions
    fit_ins = (parameters, config)
    return [(proxy, fit_ins) for proxy in sampled]
def evaluate(self, ins: fl.EvaluateIns) -> fl.EvaluateRes:
    """Evaluate the provided parameters on the local test dataset.

    `ins[0]` holds the parameters and `ins[1]` the config (only logged).
    Returns the number of test examples, the loss and the accuracy.
    """
    weights = fl.parameters_to_weights(ins[0])
    config = ins[1]
    log(
        DEBUG,
        "evaluate on %s (examples: %s), config %s",
        self.cid,
        self.num_examples_test,
        config,
    )

    # Load the provided weights into the local model
    self.model.set_weights(weights)

    # The batch size passed on equals the number of test examples
    metrics = keras_evaluate(
        self.model, self.ds_test, batch_size=self.num_examples_test
    )
    loss, acc = metrics

    return self.num_examples_test, loss, acc
def fit_round(self, rnd: int) -> Optional[Weights]:
    """Run one round of federated training.

    Asks the strategy which clients to train with which instructions,
    collects their results and returns whatever the strategy aggregates
    from them. Returns `None` if the strategy sampled no clients.
    """
    # Get clients and their respective instructions from strategy
    instructions = self.strategy.on_configure_fit(
        rnd=rnd, weights=self.weights, client_manager=self._client_manager
    )
    log(
        DEBUG,
        "fit_round: strategy sampled %s clients",
        len(instructions),
    )
    if not instructions:
        log(INFO, "fit_round: no clients sampled, cancel fit")
        return None

    # Collect training results from all clients participating in this round
    outcome = fit_clients(instructions)
    results, failures = outcome
    log(
        DEBUG,
        "fit_round received %s results and %s failures",
        len(results),
        len(failures),
    )

    # Aggregate training results
    return self.strategy.on_aggregate_fit(rnd, results, failures)
def evaluate(
    self, rnd: int
) -> Optional[Tuple[Optional[float], EvaluateResultsAndFailures]]:
    """Evaluate the current global weights on clients chosen by the strategy.

    Returns `None` if the strategy sampled no clients; otherwise a pair of
    the loss aggregated by the strategy and the raw (results, failures).
    """
    # Get clients and their respective instructions from strategy
    instructions = self.strategy.on_configure_evaluate(
        rnd=rnd, weights=self.weights, client_manager=self._client_manager
    )
    if not instructions:
        log(INFO, "evaluate: no clients sampled, cancel federated evaluation")
        return None
    log(
        DEBUG,
        "evaluate: strategy sampled %s clients",
        len(instructions),
    )

    # Evaluate current global weights on those clients
    results_and_failures = evaluate_clients(instructions)
    results, failures = results_and_failures
    log(
        DEBUG,
        "evaluate received %s results and %s failures",
        len(results),
        len(failures),
    )

    # Let the strategy aggregate the evaluation results
    aggregated_loss = self.strategy.on_aggregate_evaluate(rnd, results, failures)
    return aggregated_loss, results_and_failures
def start(self) -> None:
    """Start all instances of the cluster.

    Instances are grouped by their specs and each group is created in a
    worker thread via `create_instances`. If any group fails, the error is
    logged, the cluster is terminated and `StartFailed` is raised.

    Raises:
        StartFailed: If creating any instance group raised an exception.
            The original exception is attached as the cause.
    """
    instance_groups = group_instances_by_specs(self.instances)

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [
            executor.submit(
                create_instances, self.adapter, instance_group, self.timeout
            )
            for instance_group in instance_groups
        ]
        concurrent.futures.wait(futures)

        try:
            # `result()` re-raises whatever the worker raised
            for future in futures:
                future.result()
        # pylint: disable=broad-except
        except Exception as exc:
            log(
                ERROR,
                "Failed to start the cluster completely. Shutting down...",
            )
            log(ERROR, exc)

            for future in futures:
                future.cancel()

            # Tear down whatever was started before reporting the failure
            self.terminate()
            raise StartFailed() from exc

    for ins in self.instances:
        log(DEBUG, ins)
def exec_all(
    self, command: str, groups: Optional[List[str]] = None
) -> Dict[str, ExecInfo]:
    """Run `command` on all instances, optionally restricted to `groups`.

    The command is executed concurrently via `self.exec`. The returned
    dict maps instance names to their result; instances whose execution
    raised are logged and left out of the dict.
    """
    instance_names = self.get_instance_names(groups)
    results: Dict[str, ExecInfo] = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
        # Submit one task per instance and remember which future belongs
        # to which instance
        name_by_future = {}
        for name in instance_names:
            name_by_future[pool.submit(self.exec, name, command)] = name

        for done in concurrent.futures.as_completed(name_by_future):
            instance_name = name_by_future[done]
            try:
                results[instance_name] = done.result()
            # pylint: disable=broad-except
            except Exception as exc:
                log(ERROR, (instance_name, exc))

    return results
def upload_all(
    self, local_path: str, remote_path: str
) -> Dict[str, SFTPAttributes]:
    """Upload a local file to `remote_path` on every instance.

    The uploads run concurrently via `self.upload`. The returned dict maps
    instance names to the upload result; instances whose upload raised are
    logged and left out of the dict.
    """
    results: Dict[str, SFTPAttributes] = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
        # Submit one upload per instance and remember which future belongs
        # to which instance
        name_by_future = {}
        for name in self.get_instance_names():
            task = pool.submit(self.upload, name, local_path, remote_path)
            name_by_future[task] = name

        for done in concurrent.futures.as_completed(name_by_future):
            instance_name = name_by_future[done]
            try:
                results[instance_name] = done.result()
            # pylint: disable=broad-except
            except Exception as exc:
                log(ERROR, (instance_name, exc))

    return results
"""Return a function which returns training configurations.""" def fit_config(rnd: int) -> Dict[str, str]: """Return a configuration with static batch size and (local) epochs.""" config = { "epoch_global": str(rnd), "epochs": str(5), "batch_size": str(10), "lr_initial": str(lr_initial), "lr_decay": str(0.99), "partial_updates": "1" if partial_updates else "0", } if timeout is not None: config["timeout"] = str(timeout) return config return fit_config if __name__ == "__main__": # pylint: disable=broad-except try: main() except Exception as err: log(ERROR, "Fatal error in main") log(ERROR, err, exc_info=True, stack_info=True) # Raise the error again so the exit code is correct raise err
def main() -> None:
    """Log the download and trigger loading of the hotkey dataset."""
    log(INFO, "Download Keyword Detection")

    # The return value is not used here
    tf_hotkey_partitioned.hotkey_load()
def custom_fit(
    model: tf.keras.Model,
    dataset: tf.data.Dataset,
    num_epochs: int,
    batch_size: int,
    callbacks: List[tf.keras.callbacks.Callback],
    delay_factor: float = 0.0,
    timeout: Optional[int] = None,
) -> Tuple[bool, float, int]:
    """Train `model` on `dataset` with a hand-written training loop.

    After every batch the loop optionally sleeps for `delay_factor` times
    the batch duration and, if `timeout` is set, aborts once the total fit
    duration exceeds it. `callbacks` is accepted but not used.

    Returns:
        A tuple `(completed, fit_duration, num_examples)`: whether all
        epochs finished before the timeout, the elapsed time measured with
        `timeit.default_timer`, and the number of examples processed.
    """
    batches = dataset.batch(batch_size=batch_size, drop_remainder=False)

    # Per-epoch metrics are collected here (not returned)
    loss_history = []
    accuracy_history = []

    adam = tf.keras.optimizers.Adam()

    fit_begin = timeit.default_timer()
    num_examples = 0

    for epoch in range(num_epochs):
        log(INFO, "Starting epoch %s", epoch)

        loss_metric = tf.keras.metrics.Mean()
        accuracy_metric = tf.keras.metrics.CategoricalAccuracy()

        # Single pass over the dataset
        batch_begin = timeit.default_timer()
        seen_in_epoch = 0
        for step, (features, labels) in enumerate(batches):
            seen_in_epoch += len(features)

            # Optimize the model
            loss_value, grads = grad(model, features, labels)
            adam.apply_gradients(zip(grads, model.trainable_variables))

            # Track progress; the loss metric receives the current batch loss
            loss_metric.update_state(loss_value)
            accuracy_metric.update_state(labels, model(features, training=True))

            # Track the number of examples used for training
            num_examples += features.shape[0]

            # Optional artificial delay, proportional to the batch duration
            batch_duration = timeit.default_timer() - batch_begin
            if delay_factor > 0.0:
                time.sleep(batch_duration * delay_factor)

            # Progress log every 100 batches
            if step % 100 == 0:
                log(
                    INFO,
                    "Batch %s: loss %s (%s examples processed, batch duration: %s)",
                    step,
                    loss_value,
                    seen_in_epoch,
                    batch_duration,
                )

            # Abort as soon as the overall time budget is exceeded
            if timeout is not None:
                fit_duration = timeit.default_timer() - fit_begin
                if fit_duration > timeout:
                    log(INFO, "client timeout")
                    return (False, fit_duration, num_examples)

            # Restart the clock for the next batch
            batch_begin = timeit.default_timer()

        # End epoch
        loss_history.append(loss_metric.result())
        accuracy_history.append(accuracy_metric.result())
        log(
            INFO,
            "Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(
                epoch, loss_metric.result(), accuracy_metric.result()),
        )

    fit_duration = timeit.default_timer() - fit_begin
    return True, fit_duration, num_examples
def fit(self, num_rounds: int) -> History:
    """Run up to `num_rounds` rounds of federated learning.

    Each round trains via `fit_round`, evaluates through the strategy,
    evaluates on clients via `self.evaluate`, and then asks the strategy
    whether to continue. All losses/accuracies obtained along the way are
    recorded in the returned `History`.
    """
    history = History()

    # Initialize weights by asking one client to return theirs
    self.weights = self._get_initial_weights()
    res = self.strategy.evaluate(weights=self.weights)
    if res is not None:
        log(
            INFO,
            "initial weights (loss/accuracy): %s, %s",
            res[0],
            res[1],
        )
        history.add_loss_centralized(rnd=0, loss=res[0])
        history.add_accuracy_centralized(rnd=0, acc=res[1])

    # Run federated learning for num_rounds
    log(INFO, "[TIME] FL starting")
    start_time = timeit.default_timer()

    for current_round in range(1, num_rounds + 1):
        # Train model and replace previous global model
        new_weights = self.fit_round(rnd=current_round)
        if new_weights is not None:
            self.weights = new_weights

        # Evaluate model using strategy implementation; loss/acc stay
        # `None` if the strategy returns no result
        loss, acc = None, None
        res_cen = self.strategy.evaluate(weights=self.weights)
        if res_cen is not None:
            loss, acc = res_cen
            log(
                INFO,
                "fit progress: (%s, %s, %s, %s)",
                current_round,
                loss,
                acc,
                timeit.default_timer() - start_time,
            )
            history.add_loss_centralized(rnd=current_round, loss=loss)
            history.add_accuracy_centralized(rnd=current_round, acc=acc)

        # Evaluate model on a sample of available clients
        res_fed = self.evaluate(rnd=current_round)
        if res_fed is not None and res_fed[0] is not None:
            loss_fed, _ = res_fed
            history.add_loss_distributed(
                rnd=current_round, loss=cast(float, loss_fed)
            )

        # Conclude round; the strategy decides whether to keep going
        if not self.strategy.on_conclude_round(current_round, loss, acc):
            break

    elapsed = timeit.default_timer() - start_time
    log(INFO, "[TIME] FL finished in %s", elapsed)
    return history
def on_configure_fit(
    self, rnd: int, weights: Weights, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
    """Configure the next round of training.

    Waits (up to `WAIT_TIMEOUT`) for enough clients, samples clients
    (uniformly with 1/k in round 1, fast/slow-based afterwards) and pairs
    every sampled client with the same fit instructions. The round timeout
    is derived from past durations if there are any, else it is `t_max`.

    Returns an empty list if not enough clients became available.
    """
    # Block until `min_num_clients` are available
    sample_size, min_num_clients = self.num_fit_clients(
        client_manager.num_available()
    )
    enough_clients = client_manager.wait_for(
        num_clients=min_num_clients, timeout=WAIT_TIMEOUT
    )
    if not enough_clients:
        # Do not continue if not enough clients are available
        log(
            INFO,
            "FedFS: not enough clients available after timeout %s",
            WAIT_TIMEOUT,
        )
        return []

    # Sample clients
    if rnd != 1:
        fast_round = is_fast_round(rnd - 1, r_fast=self.r_fast, r_slow=self.r_slow)
        log(
            DEBUG,
            "FedFS round %s, sample %s clients, fast_round %s",
            str(rnd),
            str(sample_size),
            str(fast_round),
        )
        sampled = self._fs_based_sampling(
            sample_size=sample_size,
            client_manager=client_manager,
            fast_round=fast_round,
        )
    else:
        # Sample with 1/k in the first round
        log(
            DEBUG,
            "FedFS round %s, sample %s clients with 1/k",
            str(rnd),
            str(sample_size),
        )
        sampled = self._one_over_k_sampling(
            sample_size=sample_size, client_manager=client_manager
        )

    # Prepare parameters and config
    parameters = weights_to_parameters(weights)
    if self.on_fit_config_fn is None:
        config = {}
    else:
        # Use custom fit config function if provided
        config = self.on_fit_config_fn(rnd)

    # Set timeout for this round
    if not self.durations:
        # No past durations recorded yet, use the maximum timeout
        config["timeout"] = str(self.t_max)
    else:
        candidates = timeout_candidates(
            durations=self.durations,
            max_timeout=self.t_max,
        )
        timeout = next_timeout(
            candidates=candidates,
            percentile=self.dynamic_timeout_percentile,
        )
        config["timeout"] = str(timeout)

    # Every sampled client receives the same instructions
    fit_ins = (parameters, config)
    return [(proxy, fit_ins) for proxy in sampled]
def on_channel_state_change(channel_connectivity: str) -> None:
    """Write the given channel connectivity value to the debug log."""
    log(DEBUG, channel_connectivity)
def main() -> None:
    """Log the download and load Fashion-MNIST through Keras."""
    log(INFO, "Download Fashion-MNIST")

    # The loaded arrays are discarded here
    tf.keras.datasets.fashion_mnist.load_data()
def create_instances(
    self,
    num_cpu: int,
    num_ram: float,
    timeout: int,
    num_instance: int = 1,
    gpu: bool = False,
) -> List[AdapterInstance]:
    """Create one or more EC2 instance(s) of the same type.

    Args:
        num_cpu (int): Number of instance vCPU (values in
            ec2_adapter.INSTANCE_TYPES_CPU or INSTANCE_TYPES_GPU)
        num_ram (float): RAM in GB (values in ec2_adapter.INSTANCE_TYPES_CPU
            or INSTANCE_TYPES_GPU)
        timeout (int): Timeout in minutes
        num_instance (int): Number of instances to start if currently
            available in EC2
        gpu (bool): If True, the instance type is looked up in
            INSTANCE_TYPES_GPU instead of INSTANCE_TYPES_CPU

    Returns:
        The result of `self.list_instances` for the created instance ids.

    Raises:
        EC2CreateInstanceFailure: If waiting for the instances to become
            reachable raised `EC2StatusTimeout`. The instances are
            terminated first and the timeout is attached as the cause.
    """
    # The instance will be set to terminate after shutdown
    # This is a fail safe in case something happens and the instances
    # are not correctly shutdown
    user_data = ["#!/bin/bash", f"sudo shutdown -P {timeout}"]
    user_data_str = "\n".join(user_data)

    instance_type, hourly_price = find_instance_type(
        num_cpu, num_ram, INSTANCE_TYPES_GPU if gpu else INSTANCE_TYPES_CPU)

    hourly_price_total = round(num_instance * hourly_price, 2)

    log(
        INFO,
        "Starting %s instances of type %s which in total will roughly cost $%s an hour.",
        num_instance,
        instance_type,
        hourly_price_total,
    )

    result: EC2RunInstancesResult = self.ec2.run_instances(
        ImageId=self.image_id,
        # We always want an exact number of instances
        MinCount=num_instance,
        MaxCount=num_instance,
        InstanceType=instance_type,
        KeyName=self.key_name,
        IamInstanceProfile={"Name": "FlowerInstanceProfile"},
        SubnetId=self.subnet_id,
        SecurityGroupIds=self.security_group_ids,
        TagSpecifications=self.tag_specifications,
        InstanceInitiatedShutdownBehavior="terminate",
        UserData=user_data_str,
    )

    instance_ids = [ins["InstanceId"] for ins in result["Instances"]]

    # As soon as all instances status is "running" we have to check the InstanceStatus which
    # reports impaired functionality that stems from issues internal to the instance, such as
    # impaired reachability
    try:
        self._wait_until_instances_are_reachable(instance_ids=instance_ids)
    except EC2StatusTimeout as exc:
        # Do not leave half-started instances behind
        self.terminate_instances(instance_ids)
        raise EC2CreateInstanceFailure() from exc

    return self.list_instances(instance_ids=instance_ids)
def start_server(
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
) -> None:
    """Start a Flower server using the gRPC transport layer.

    Args:
        server_address: Address handed to `start_insecure_grpc_server`.
        server: Server to run. If `None`, a `Server` with a
            `SimpleClientManager` and `strategy` is created.
        config: Optional settings; only `num_rounds` is read (default 1).
            The passed dict is not modified.
        strategy: Strategy for the default server; only used when `server`
            is `None`. Falls back to `FedAvg()`.

    The gRPC server is stopped when training and the final evaluation are
    done, and also when either of them raises.
    """
    # Create server instance if none was given
    if server is None:
        client_manager = SimpleClientManager()
        if strategy is None:
            strategy = FedAvg()
        server = Server(client_manager=client_manager, strategy=strategy)

    # Set default config values on a copy so the caller's dict stays untouched
    config = {} if config is None else dict(config)
    config.setdefault("num_rounds", 1)

    # Start gRPC server
    grpc_server = start_insecure_grpc_server(
        client_manager=server.client_manager(), server_address=server_address)
    log(INFO, "Flower server running (insecure, %s rounds)",
        config["num_rounds"])

    try:
        # Fit model
        hist = server.fit(num_rounds=config["num_rounds"])
        log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
        log(INFO, "app_fit: accuracies_distributed %s",
            str(hist.accuracies_distributed))
        log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
        log(INFO, "app_fit: accuracies_centralized %s",
            str(hist.accuracies_centralized))

        # Temporary workaround to force distributed evaluation
        server.strategy.eval_fn = None  # type: ignore

        # Evaluate the final trained model
        res = server.evaluate(rnd=-1)
        if res is not None:
            loss, (results, failures) = res
            log(INFO, "app_evaluate: federated loss: %s", str(loss))
            log(
                INFO,
                "app_evaluate: results %s",
                str([(result[0].cid, result[1]) for result in results]),
            )
            log(INFO, "app_evaluate: failures %s", str(failures))
        else:
            log(INFO, "app_evaluate: no evaluation result")
    finally:
        # Stop the gRPC server, also when fit/evaluate raised
        grpc_server.stop(1)
def main() -> None:
    """Start server and train a number of rounds.

    Loads the server setting, the Fashion-MNIST test data and the model
    used for centralized evaluation, builds the strategy named in the
    setting and starts the Flower server with it.

    Raises:
        ValueError: If the configured strategy is unknown, or if a strategy
            that needs `training_round_timeout` is selected without one.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_fashion_mnist_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1
    )
    if server_setting.dry_run:
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = orig_cnn(input_shape=(28, 28, 1), seed=SEED)

    # Create client_manager
    client_manager = fl.SimpleClientManager()

    # Strategy
    eval_fn = get_eval_fn(model=model, num_classes=10, xy_test=(x_test, y_test))
    on_fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    # Exactly one branch builds the strategy; unknown names fail explicitly
    # instead of leaving `strategy` unbound
    if server_setting.strategy == "fedavg":
        strategy = fl.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    elif server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        # Without an explicit short timeout, use half the regular timeout
        t_fast = (
            math.ceil(0.5 * server_setting.training_round_timeout)
            if server_setting.training_round_timeout_short is None
            else server_setting.training_round_timeout_short
        )
        strategy = fl.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.8,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "fedfs-v0":
        if server_setting.training_round_timeout is None:
            raise ValueError("No `training_round_timeout` set for `fedfs-v0` strategy")
        # Without an explicit short timeout, use half the regular timeout
        t_fast = (
            math.ceil(0.5 * server_setting.training_round_timeout)
            if server_setting.training_round_timeout_short is None
            else server_setting.training_round_timeout_short
        )
        strategy = fl.strategy.FedFSv0(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "fedfs-v1":
        if server_setting.training_round_timeout is None:
            raise ValueError("No `training_round_timeout` set for `fedfs-v1` strategy")
        strategy = fl.strategy.FedFSv1(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            dynamic_timeout_percentile=0.8,
            r_fast=1,
            r_slow=1,
            t_max=server_setting.training_round_timeout,
            use_past_contributions=True,
        )
    elif server_setting.strategy == "qffedavg":
        strategy = fl.strategy.QffedAvg(
            q_param=0.2,
            qffl_learning_rate=0.1,
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    else:
        raise ValueError("Unknown strategy: " + str(server_setting.strategy))

    # Run server
    log(INFO, "Instantiating server, strategy: %s", str(strategy))
    server = fl.Server(client_manager=client_manager, strategy=strategy)
    fl.app.server.start_server(
        DEFAULT_SERVER_ADDRESS,
        server,
        config={"num_rounds": server_setting.rounds},
    )
def run(baseline: str, setting: str, adapter: str) -> None:
    """Run a baseline on a freshly configured cluster.

    Executes nine logged steps: configure and start the cluster, upload
    and install the wheel, download the dataset, start the logserver, the
    Flower server and the Flower clients, and finally start the shutdown
    watcher. Afterwards the command for tailing the central logfile is
    logged.

    Raises:
        Exception: If `baseline` is none of `tf_cifar`, `tf_fashion_mnist`
            or `tf_hotkey`.
    """
    print(f"Starting baseline with {setting} settings.")

    if adapter == "docker":
        wheel_remote_path = f"/root/{WHEEL_FILENAME}"
    else:
        wheel_remote_path = f"/home/ubuntu/{WHEEL_FILENAME}"

    # Map each known baseline to the function providing its settings
    setting_loaders = {
        "tf_cifar": tf_cifar_settings.get_setting,
        "tf_fashion_mnist": tf_fashion_mnist_settings.get_setting,
        "tf_hotkey": tf_hotkey_settings.get_setting,
    }
    if baseline not in setting_loaders:
        raise Exception("Setting not found.")
    settings = setting_loaders[baseline](setting)

    # Get instances and add a logserver to the list
    instances = settings.instances
    logserver_instance = Instance(
        name="logserver", group="logserver", num_cpu=2, num_ram=2
    )
    instances.append(logserver_instance)

    # Configure cluster
    log(INFO, "(1/9) Configure cluster.")
    cluster = configure_cluster(adapter, instances, baseline, setting)

    # Start the cluster; this takes some time
    log(INFO, "(2/9) Start cluster.")
    cluster.start()

    # Upload wheel to all instances
    log(INFO, "(3/9) Upload wheel to all instances.")
    cluster.upload_all(WHEEL_LOCAL_PATH, wheel_remote_path)

    # Install the wheel on all instances
    log(INFO, "(4/9) Install wheel on all instances.")
    cluster.exec_all(command.install_wheel(wheel_remote_path))

    # Download datasets in server and clients
    log(INFO, "(5/9) Download dataset on server and clients.")
    download_cmd = command.download_dataset(baseline=baseline)
    cluster.exec_all(download_cmd, groups=["server", "clients"])

    # Start logserver
    log(INFO, "(6/9) Start logserver.")
    logserver = cluster.get_instance("logserver")
    cluster.exec(
        logserver.name,
        command.start_logserver(
            logserver_s3_bucket=CONFIG.get("aws", "logserver_s3_bucket"),
            logserver_s3_key=f"{baseline}_{setting}_{now()}.log",
        ),
    )

    # Start Flower server on Flower server instances
    log(INFO, "(7/9) Start server.")
    cluster.exec(
        "server",
        command.start_server(
            log_host=f"{logserver.private_ip}:8081",
            baseline=baseline,
            setting=setting,
        ),
    )

    # Start Flower clients
    log(INFO, "(8/9) Start clients.")
    server = cluster.get_instance("server")

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
        # Launch one client per client setting, then wait for all of them
        client_futures = [
            pool.submit(
                cluster.exec,
                client_setting.instance_name,
                command.start_client(
                    log_host=f"{logserver.private_ip}:8081",
                    server_address=f"{server.private_ip}:8080",
                    baseline=baseline,
                    setting=setting,
                    cid=client_setting.cid,
                ),
            )
            for client_setting in settings.clients
        ]
        concurrent.futures.wait(client_futures)

    # Shutdown server and client instance after 10min if not at least one Flower
    # process is running it
    log(INFO, "(9/9) Start shutdown watcher script.")
    cluster.exec_all(command.watch_and_shutdown("flower", adapter))

    # Give user info how to tail logfile
    if adapter == "docker":
        private_key = DOCKER_PRIVATE_KEY
    else:
        private_key = path.expanduser(CONFIG.get("ssh", "private_key"))

    log(
        INFO,
        "If you would like to tail the central logfile run:\n\n\t%s\n",
        command.tail_logfile(adapter, private_key, logserver),
    )