def main() -> None:
    """Start Flower baseline.

    Parses ``--baseline``, ``--setting`` and ``--adapter`` from the command
    line, configures the logger and hands the selection to ``run``.
    """
    # Every setting name known to any baseline. Sorted so help/error output
    # lists the choices in a stable order: iterating a ``set`` of strings
    # yields a different order in each process under hash randomization.
    all_settings = sorted(
        set(tf_cifar_settings.SETTINGS.keys())
        | set(tf_fashion_mnist_settings.SETTINGS.keys())
        | set(tf_hotkey_settings.SETTINGS.keys())
    )

    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--baseline",
        type=str,
        required=True,
        choices=["tf_cifar", "tf_fashion_mnist", "tf_hotkey"],
        help="Name of baseline name to run.",
    )
    parser.add_argument(
        "--setting",
        type=str,
        required=True,
        choices=all_settings,
        help="Name of setting to run.",
    )
    parser.add_argument(
        "--adapter",
        type=str,
        required=True,
        choices=["docker", "ec2"],
        help="Set adapter to be used.",
    )
    args = parser.parse_args()

    # Configure logger
    configure(f"flower_{args.baseline}_{args.setting}")

    run(baseline=args.baseline, setting=args.setting, adapter=args.adapter)
def main() -> None:
    """Start Flower baseline.

    The choices offered for ``--setting`` depend on the selected
    ``--baseline``. The baseline is pre-parsed first, so that only its
    settings are accepted and listed in help/error messages. The logger is
    then configured and the selection is handed to ``run``.
    """
    # When adding a new setting make sure to modify the load_baseline_setting function
    settings_by_baseline = {
        "tf_cifar": list(tf_cifar_settings.SETTINGS.keys()),
        "tf_fashion_mnist": list(tf_fashion_mnist_settings.SETTINGS.keys()),
        "tf_hotkey": list(tf_hotkey_settings.SETTINGS.keys()),
    }
    possible_baselines = list(settings_by_baseline)

    # Pre-parse only --baseline. Unlike scanning sys.argv for "--baseline=X",
    # this handles both "--baseline=X" and "--baseline X" and never raises
    # on a missing/unknown value; validation is left to the main parser.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--baseline", type=str, nargs="?", default=None)
    pre_args, _ = pre_parser.parse_known_args()

    if pre_args.baseline in settings_by_baseline:
        # Show only relevant settings based on baseline as choices
        # for --setting parameter
        possible_settings = settings_by_baseline[pre_args.baseline]
    else:
        # Baseline missing or invalid: accept any known setting so that the
        # main parser reports the actual problem (the --baseline argument).
        possible_settings = sorted(
            {name for names in settings_by_baseline.values() for name in names}
        )

    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--baseline",
        type=str,
        required=True,
        choices=possible_baselines,
        help="Name of baseline name to run.",
    )
    parser.add_argument(
        "--setting",
        type=str,
        required=True,
        choices=possible_settings,
        help="Name of setting to run.",
    )
    parser.add_argument(
        "--adapter",
        type=str,
        required=True,
        choices=["docker", "ec2"],
        help="Set adapter to be used.",
    )
    args = parser.parse_args()

    # Configure logger
    configure(f"flower_{args.baseline}_{args.setting}")

    run(baseline=args.baseline, setting=args.setting, adapter=args.adapter)
def main() -> None:
    """Load data, create and start CIFAR-10/100 client.

    Reads the client setting selected on the command line and configures
    logging. Builds a ResNet50v2 with ``NUM_CLASSES`` outputs, loads this
    client's data partition and connects to the server at
    ``args.server_address``.
    """
    args = parse_args()
    client_setting = get_client_setting(args.setting, args.cid)

    # Configure logger
    configure(identifier=f"client:{client_setting.cid}", host=args.log_host)
    log(INFO, "Starting client, settings: %s", client_setting)

    # Load model
    model = resnet50v2(input_shape=(32, 32, 3), num_classes=NUM_CLASSES, seed=SEED)

    # Load local data partition. The dataset variant is derived from
    # NUM_CLASSES (as the CIFAR server does) so that the data always matches
    # the model's output layer; it was previously hard-coded to CIFAR-10.
    (xy_train_partitions, xy_test_partitions), _ = tf_cifar_partitioned.load_data(
        iid_fraction=client_setting.iid_fraction,
        num_partitions=client_setting.num_clients,
        cifar100=NUM_CLASSES == 100,
    )
    x_train, y_train = xy_train_partitions[client_setting.partition]
    x_test, y_test = xy_test_partitions[client_setting.partition]
    if client_setting.dry_run:
        # Keep only a handful of examples so a dry run is fast.
        x_train = x_train[0:100]
        y_train = y_train[0:100]
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Start client
    client = VisionClassificationClient(
        client_setting.cid,
        model,
        (x_train, y_train),
        (x_test, y_test),
        client_setting.delay_factor,
        NUM_CLASSES,
        augment=True,
        augment_horizontal_flip=True,
        augment_offset=2,
    )
    fl.client.start_client(args.server_address, client)
def main() -> None:
    """Load data, create and start Fashion-MNIST client.

    Looks up the client setting chosen on the command line and configures
    logging. Builds the CNN, loads the client's own data partition
    (optionally truncated for a dry run) and connects to the server.
    """
    cli_args = parse_args()
    setting = get_client_setting(cli_args.setting, cli_args.cid)

    # Logging is set up before anything else so later steps are traced.
    configure(identifier=f"client:{setting.cid}", host=cli_args.log_host)
    log(INFO, "Starting client, settings: %s", setting)

    net = orig_cnn(input_shape=(28, 28, 1), seed=SEED)

    # Split the loaded data into per-client train/test partitions and pick ours.
    partitions, _ = tf_fashion_mnist_partitioned.load_data(
        iid_fraction=setting.iid_fraction,
        num_partitions=setting.num_clients,
    )
    train_parts, test_parts = partitions
    x_train, y_train = train_parts[setting.partition]
    x_test, y_test = test_parts[setting.partition]

    if setting.dry_run:
        # Shrink the data so a dry run finishes quickly.
        x_train, y_train = x_train[:100], y_train[:100]
        x_test, y_test = x_test[:50], y_test[:50]

    num_classes = 10
    client = VisionClassificationClient(
        setting.cid,
        net,
        (x_train, y_train),
        (x_test, y_test),
        setting.delay_factor,
        num_classes,
        augment=True,
        augment_horizontal_flip=False,
        augment_offset=1,
    )
    fl.client.start_client(cli_args.server_address, client)
def main() -> None:
    """Load data, create and start client.

    Uses the client setting chosen on the command line. Builds the keyword
    CNN and loads this client's data partition (optionally truncated for a
    dry run). Then starts the client against the configured server address.
    """
    cli_args = parse_args()
    setting = get_client_setting(cli_args.setting, cli_args.cid)

    configure(identifier=f"client:{setting.cid}", host=cli_args.log_host)

    net = keyword_cnn(input_shape=(80, 40, 1), seed=SEED)

    # Pick this client's slice out of the partitioned train/test data.
    partitions, _ = tf_hotkey_partitioned.load_data(
        iid_fraction=setting.iid_fraction,
        num_partitions=setting.num_clients,
    )
    train_parts, test_parts = partitions
    x_train, y_train = train_parts[setting.partition]
    x_test, y_test = test_parts[setting.partition]

    if setting.dry_run:
        # Shrink the data so a dry run finishes quickly.
        x_train, y_train = x_train[:100], y_train[:100]
        x_test, y_test = x_test[:50], y_test[:50]

    num_classes = 10
    client = VisionClassificationClient(
        setting.cid,
        net,
        (x_train, y_train),
        (x_test, y_test),
        setting.delay_factor,
        num_classes,
        normalization_factor=100.0,
    )
    fl.client.start_client(cli_args.server_address, client)
def main() -> None:
    """Start server and train a number of rounds.

    Loads the evaluation data and a model for centralized evaluation. Builds
    the federated strategy named by ``server_setting.strategy`` and runs the
    Flower server for ``server_setting.rounds`` rounds.

    Raises:
        ValueError: If ``fast-and-slow`` or ``fedfs-v0`` is selected without a
            ``training_round_timeout``, or if the strategy name is unsupported.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_hotkey_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1
    )
    if server_setting.dry_run:
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = keyword_cnn(input_shape=(80, 40, 1), seed=SEED)

    # Strategy
    eval_fn = get_eval_fn(model=model, num_classes=10, xy_test=(x_test, y_test))
    on_fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    # An elif chain with a final else guarantees `strategy` is bound; the
    # previous independent `if`s let an unsupported name fall through and
    # fail later with an UnboundLocalError.
    if server_setting.strategy == "fedavg":
        strategy = fl.server.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    elif server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        strategy = fl.server.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.9,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            t_fast=math.ceil(0.5 * server_setting.training_round_timeout),
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "fedfs-v0":
        if server_setting.training_round_timeout is None:
            raise ValueError("No `training_round_timeout` set for `fedfs-v0` strategy")
        # Short timeout: explicit setting if given, else half the long one.
        t_fast = (
            math.ceil(0.5 * server_setting.training_round_timeout)
            if server_setting.training_round_timeout_short is None
            else server_setting.training_round_timeout_short
        )
        strategy = fl.server.strategy.FedFSv0(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "qffedavg":
        strategy = fl.server.strategy.QFedAvg(
            q_param=0.2,
            qffl_learning_rate=0.1,
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    else:
        raise ValueError(f"Unsupported strategy: {server_setting.strategy!r}")

    # Run server
    fl.server.start_server(
        DEFAULT_SERVER_ADDRESS,
        config={"num_rounds": server_setting.rounds},
        strategy=strategy,
    )
def main() -> None:
    """Start server and train a number of rounds.

    Loads the CIFAR evaluation data (CIFAR-100 when ``NUM_CLASSES == 100``)
    and a ResNet50v2 for centralized evaluation. Builds the strategy named by
    ``server_setting.strategy`` and runs a Flower server with a
    ``SimpleClientManager`` for ``server_setting.rounds`` rounds.

    Raises:
        ValueError: If ``fast-and-slow`` is selected without a
            ``training_round_timeout``, or if the strategy name is unsupported.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_cifar_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1, cifar100=NUM_CLASSES == 100
    )
    if server_setting.dry_run:
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = resnet50v2(input_shape=(32, 32, 3), num_classes=NUM_CLASSES, seed=SEED)

    # Create client_manager
    client_manager = fl.server.SimpleClientManager()

    # Strategy
    eval_fn = get_eval_fn(
        model=model, num_classes=NUM_CLASSES, xy_test=(x_test, y_test)
    )
    fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    # An elif chain with a final else guarantees `strategy` is bound; the
    # previous independent `if`s let an unsupported name fall through and
    # fail later with an UnboundLocalError when building the server.
    if server_setting.strategy == "fedavg":
        strategy = fl.server.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=fit_config_fn,
        )
    elif server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        strategy = fl.server.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.8,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            t_fast=math.ceil(0.5 * server_setting.training_round_timeout),
            t_slow=server_setting.training_round_timeout,
        )
    else:
        raise ValueError(f"Unsupported strategy: {server_setting.strategy!r}")

    # Run server
    server = fl.server.Server(client_manager=client_manager, strategy=strategy)
    fl.server.start_server(
        DEFAULT_SERVER_ADDRESS,
        server,
        config={"num_rounds": server_setting.rounds},
    )