def get_client_setting(setting: str, cid: str) -> ClientSetting:
    """Look up the client setting identified by `cid` within a named setting.

    The clients of `get_setting(setting)` are searched in order and the first
    one whose `cid` equals the requested `cid` is returned.

    Raises:
        ClientSettingNotFound: If none of the clients has the requested cid.
    """
    candidates = (
        candidate
        for candidate in get_setting(setting).clients
        if candidate.cid == cid
    )
    found = next(candidates, None)
    if found is None:
        raise ClientSettingNotFound()
    return found
def load_baseline_setting(baseline: str, setting: str) -> Baseline:
    """Resolve the named setting from the settings module of a baseline.

    Recognised baselines are "tf_cifar", "tf_fashion_mnist" and "tf_hotkey";
    the lookup itself is delegated to the `get_setting` function of the
    matching settings module.

    Raises:
        Exception: If `baseline` is not one of the recognised names.
    """
    # First pick the settings module, then perform a single lookup on it
    if baseline == "tf_cifar":
        settings_module = tf_cifar_settings
    elif baseline == "tf_fashion_mnist":
        settings_module = tf_fashion_mnist_settings
    elif baseline == "tf_hotkey":
        settings_module = tf_hotkey_settings
    else:
        raise Exception("Setting not found.")

    return settings_module.get_setting(setting)
def main() -> None:
    """Start server and train a number of rounds.

    Parses the command-line arguments, configures logging, loads the hotkey
    test data and a keyword CNN for centralized evaluation, builds the strategy
    named by `server_setting.strategy` and starts the Flower server for
    `server_setting.rounds` rounds.

    Raises:
        ValueError: If the configured strategy name is not one of "fedavg",
            "fast-and-slow", "fedfs-v0" or "qffedavg", or if "fast-and-slow" /
            "fedfs-v0" is selected without `training_round_timeout` being set.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_hotkey_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1
    )
    if server_setting.dry_run:
        # Dry runs only evaluate on the first 50 test examples
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = keyword_cnn(input_shape=(80, 40, 1), seed=SEED)

    # Strategy
    eval_fn = get_eval_fn(model=model, num_classes=10, xy_test=(x_test, y_test))
    on_fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    # The strategy names are mutually exclusive, so use a single chain with an
    # explicit fallback; otherwise an unknown name would leave `strategy`
    # unbound and only fail later with an UnboundLocalError.
    if server_setting.strategy == "fedavg":
        strategy = fl.server.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    elif server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        strategy = fl.server.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.9,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            # The fast timeout is half of the regular round timeout, rounded up
            t_fast=math.ceil(0.5 * server_setting.training_round_timeout),
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "fedfs-v0":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fedfs-v0` strategy"
            )
        # Prefer the explicitly configured short timeout; otherwise fall back
        # to half of the regular round timeout, rounded up
        if server_setting.training_round_timeout_short is None:
            t_fast = math.ceil(0.5 * server_setting.training_round_timeout)
        else:
            t_fast = server_setting.training_round_timeout_short
        strategy = fl.server.strategy.FedFSv0(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "qffedavg":
        strategy = fl.server.strategy.QFedAvg(
            q_param=0.2,
            qffl_learning_rate=0.1,
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    else:
        raise ValueError(f"Unsupported strategy: {server_setting.strategy!r}")

    # Run server
    fl.server.start_server(
        DEFAULT_SERVER_ADDRESS,
        config={"num_rounds": server_setting.rounds},
        strategy=strategy,
    )
def run(baseline: str, setting: str, adapter: str) -> None:
    """Run baseline.

    Resolves the settings of the requested baseline, brings up a cluster that
    consists of the configured instances plus one extra "logserver" instance,
    installs the wheel, downloads the dataset, and then starts the logserver,
    the Flower server and all Flower clients. Finally a shutdown watcher is
    started on every instance and a command for tailing the central logfile is
    logged.

    Args:
        baseline: Name of the baseline ("tf_cifar", "tf_fashion_mnist" or
            "tf_hotkey").
        setting: Name of the setting to look up for that baseline.
        adapter: Cluster adapter name; "docker" selects the docker-specific
            wheel path and private key, any other value the ubuntu home
            directory and the SSH key from the configuration.

    Raises:
        Exception: If `baseline` is not one of the supported names.
    """
    # Imported locally because only this function logs at ERROR level
    from logging import ERROR

    print(f"Starting baseline with {setting} settings.")

    wheel_remote_path = (
        f"/root/{WHEEL_FILENAME}"
        if adapter == "docker"
        else f"/home/ubuntu/{WHEEL_FILENAME}"
    )

    if baseline == "tf_cifar":
        settings = tf_cifar_settings.get_setting(setting)
    elif baseline == "tf_fashion_mnist":
        settings = tf_fashion_mnist_settings.get_setting(setting)
    elif baseline == "tf_hotkey":
        settings = tf_hotkey_settings.get_setting(setting)
    else:
        raise Exception("Setting not found.")

    # Get instances and add a logserver to the list. A new list is built
    # instead of appending to `settings.instances`, so the settings object is
    # left untouched and repeated calls do not accumulate logserver entries.
    instances = [
        *settings.instances,
        Instance(name="logserver", group="logserver", num_cpu=2, num_ram=2),
    ]

    # Configure cluster
    log(INFO, "(1/9) Configure cluster.")
    cluster = configure_cluster(adapter, instances, baseline, setting)

    # Start the cluster; this takes some time
    log(INFO, "(2/9) Start cluster.")
    cluster.start()

    # Upload wheel to all instances
    log(INFO, "(3/9) Upload wheel to all instances.")
    cluster.upload_all(WHEEL_LOCAL_PATH, wheel_remote_path)

    # Install the wheel on all instances
    log(INFO, "(4/9) Install wheel on all instances.")
    cluster.exec_all(command.install_wheel(wheel_remote_path))

    # Download datasets in server and clients
    log(INFO, "(5/9) Download dataset on server and clients.")
    cluster.exec_all(
        command.download_dataset(baseline=baseline), groups=["server", "clients"]
    )

    # Start logserver
    log(INFO, "(6/9) Start logserver.")
    logserver = cluster.get_instance("logserver")
    cluster.exec(
        logserver.name,
        command.start_logserver(
            logserver_s3_bucket=CONFIG.get("aws", "logserver_s3_bucket"),
            logserver_s3_key=f"{baseline}_{setting}_{now()}.log",
        ),
    )

    # Start Flower server on Flower server instances
    log(INFO, "(7/9) Start server.")
    cluster.exec(
        "server",
        command.start_server(
            log_host=f"{logserver.private_ip}:8081",
            baseline=baseline,
            setting=setting,
        ),
    )

    # Start Flower clients
    log(INFO, "(8/9) Start clients.")
    server = cluster.get_instance("server")

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        # Submit one start command per client and remember which cid each
        # future belongs to, so failures can be attributed afterwards
        futures = {
            executor.submit(
                cluster.exec,
                client_setting.instance_name,
                command.start_client(
                    log_host=f"{logserver.private_ip}:8081",
                    server_address=f"{server.private_ip}:8080",
                    baseline=baseline,
                    setting=setting,
                    cid=client_setting.cid,
                ),
            ): client_setting.cid
            for client_setting in settings.clients
        }
        concurrent.futures.wait(futures)

    # `wait` never re-raises exceptions from the workers; report them instead
    # of dropping them silently, but keep going so the remaining steps still
    # run (best effort, as before).
    for future, cid in futures.items():
        error = future.exception()
        if error is not None:
            log(ERROR, "Failed to start client %s: %s", cid, error)

    # Shutdown server and client instance after 10min if not at least one
    # Flower process is running on it
    log(INFO, "(9/9) Start shutdown watcher script.")
    cluster.exec_all(command.watch_and_shutdown("flower", adapter))

    # Give user info how to tail logfile
    private_key = (
        DOCKER_PRIVATE_KEY
        if adapter == "docker"
        else path.expanduser(CONFIG.get("ssh", "private_key"))
    )

    log(
        INFO,
        "If you would like to tail the central logfile run:\n\n\t%s\n",
        command.tail_logfile(adapter, private_key, logserver),
    )