def main() -> None:
    """Start the CIFAR-10 federated-learning server and train a number of rounds.

    Parses command-line arguments with ``parse_args()``, configures the
    logger, loads the CIFAR-10 test split for centralized evaluation, builds a
    ResNet50v2 model, and starts a Flower server using
    ``flwr.strategy.DefaultStrategy``. Strategy parameters and the number of
    rounds are taken from ``get_setting(args.setting).server``.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server

    # Dataset is CIFAR-10 (``cifar100=False`` below); the model head and the
    # evaluation function must agree on this class count.
    num_classes = 10

    # Load evaluation data: the dataset is loaded as a single partition and
    # only the test split returned by ``load_partition`` is kept.
    xy_partitions, xy_test = tf_cifar_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1, cifar100=False
    )
    _, xy_test = load_partition(
        xy_partitions,
        xy_test,
        partition=0,
        num_clients=1,
        seed=SEED,
        dry_run=server_setting.dry_run,
    )

    # Load model (for centralized evaluation)
    model = resnet50v2(
        input_shape=(32, 32, 3), num_classes=num_classes, seed=SEED
    )

    # Create client_manager, strategy, and server
    client_manager = flwr.SimpleClientManager()
    strategy = flwr.strategy.DefaultStrategy(
        fraction_fit=server_setting.sample_fraction,
        min_fit_clients=server_setting.min_sample_size,
        min_available_clients=server_setting.min_num_clients,
        eval_fn=get_eval_fn(
            model=model, num_classes=num_classes, xy_test=xy_test
        ),
        on_fit_config_fn=get_on_fit_config_fn(
            server_setting.lr_initial, server_setting.training_round_timeout
        ),
    )
    server = flwr.Server(client_manager=client_manager, strategy=strategy)

    # Run server
    flwr.app.start_server(
        DEFAULT_GRPC_SERVER_ADDRESS,
        DEFAULT_GRPC_SERVER_PORT,
        server,
        config={"num_rounds": server_setting.rounds},
    )
def main() -> None:
    """Start the Fashion-MNIST Flower server and train for ``--rounds`` rounds.

    Parses command-line options (server address/port, number of rounds,
    client-sampling parameters), loads the Fashion-MNIST test split for
    centralized evaluation, and starts a Flower server that uses
    ``fl.strategy.DefaultStrategy``. The number of rounds defaults to 1.
    """
    parser = argparse.ArgumentParser(description="Flower")
    # Help strings use ``%(default)s`` so the advertised default always matches
    # the actual default (e.g. the DEFAULT_GRPC_SERVER_* constants).
    parser.add_argument(
        "--grpc_server_address",
        type=str,
        default=DEFAULT_GRPC_SERVER_ADDRESS,
        help="gRPC server address (default: %(default)s)",
    )
    parser.add_argument(
        "--grpc_server_port",
        type=int,
        default=DEFAULT_GRPC_SERVER_PORT,
        help="gRPC server port (default: %(default)s)",
    )
    parser.add_argument(
        "--rounds",
        type=int,
        default=1,
        help="Number of rounds of federated learning (default: %(default)s)",
    )
    parser.add_argument(
        "--sample_fraction",
        type=float,
        default=0.1,
        help=(
            "Fraction of available clients used for fit/evaluate "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--min_sample_size",
        type=int,
        default=1,
        help="Minimum number of clients used for fit/evaluate (default: %(default)s)",
    )
    parser.add_argument(
        "--min_num_clients",
        type=int,
        default=1,
        help=(
            "Minimum number of available clients required for sampling "
            "(default: %(default)s)"
        ),
    )
    # NOTE(review): ``args.cid`` is never read by the server below; the option
    # is presumably accepted so shared launch commands can pass it. Kept for
    # CLI backward compatibility — confirm before removing.
    parser.add_argument("--cid", type=str, help="Client CID (no default)")
    args = parser.parse_args()

    # Load evaluation data (only the test split is used by the server)
    _, xy_test = fashion_mnist.load_data(partition=0, num_partitions=1)

    # Create client_manager, strategy, and server
    client_manager = fl.SimpleClientManager()
    strategy = fl.strategy.DefaultStrategy(
        fraction_fit=args.sample_fraction,
        min_fit_clients=args.min_sample_size,
        min_available_clients=args.min_num_clients,
        eval_fn=get_eval_fn(xy_test=xy_test),
        on_fit_config_fn=fit_config,
    )
    server = fl.Server(client_manager=client_manager, strategy=strategy)

    # Run server
    fl.app.start_server(
        args.grpc_server_address,
        args.grpc_server_port,
        server,
        config={"num_rounds": args.rounds},
    )