# Example 1
def main() -> None:
    """Start the Flower server and train for the configured number of rounds.

    Server parameters (client sampling, minimum client counts, initial
    learning rate, training-round timeout, dry-run flag, and round count) are
    read from ``get_setting(args.setting).server``. The logger is configured
    with ``args.log_host``.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server

    # CIFAR-10 is loaded below (cifar100=False), so the model head and the
    # evaluation function must both use 10 classes; keep them in one place.
    num_classes = 10

    # Load evaluation data as a single partition; only the test split returned
    # by load_partition is used (centralized evaluation on the server).
    xy_partitions, xy_test = tf_cifar_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1, cifar100=False
    )
    _, xy_test = load_partition(
        xy_partitions,
        xy_test,
        partition=0,
        num_clients=1,
        seed=SEED,
        dry_run=server_setting.dry_run,
    )

    # Load model (for centralized evaluation)
    model = resnet50v2(input_shape=(32, 32, 3), num_classes=num_classes, seed=SEED)

    # Create client_manager, strategy, and server
    client_manager = flwr.SimpleClientManager()
    strategy = flwr.strategy.DefaultStrategy(
        fraction_fit=server_setting.sample_fraction,
        min_fit_clients=server_setting.min_sample_size,
        min_available_clients=server_setting.min_num_clients,
        eval_fn=get_eval_fn(model=model, num_classes=num_classes, xy_test=xy_test),
        on_fit_config_fn=get_on_fit_config_fn(
            server_setting.lr_initial, server_setting.training_round_timeout
        ),
    )
    # Alternative strategy for experiments. It reads the same server_setting
    # fields as DefaultStrategy above (the values no longer come from args).
    # strategy = flwr.strategy.FastAndSlow(
    #     fraction_fit=server_setting.sample_fraction,
    #     min_fit_clients=server_setting.min_sample_size,
    #     min_available_clients=server_setting.min_num_clients,
    #     eval_fn=get_eval_fn(model=model, num_classes=num_classes, xy_test=xy_test),
    #     on_fit_config_fn=get_on_fit_config_fn(
    #         server_setting.lr_initial, server_setting.training_round_timeout
    #     ),
    #     r_fast=1,
    #     r_slow=1,
    #     t_fast=20,
    #     t_slow=40,
    # )

    server = flwr.Server(client_manager=client_manager, strategy=strategy)

    # Run server
    flwr.app.start_server(
        DEFAULT_GRPC_SERVER_ADDRESS,
        DEFAULT_GRPC_SERVER_PORT,
        server,
        config={"num_rounds": server_setting.rounds},
    )
# Example 2
def main() -> None:
    """Start the Flower server and run federated training.

    Command-line flags set the gRPC endpoint, the number of federated rounds
    (default: 1), and the client sampling parameters. ``xy_test`` from
    ``fashion_mnist.load_data`` is passed to ``get_eval_fn`` for centralized
    evaluation.
    """
    parser = argparse.ArgumentParser(description="Flower")
    # Help texts use argparse's %(default)s so the displayed default always
    # matches the actual default (including the imported constants).
    parser.add_argument(
        "--grpc_server_address",
        type=str,
        default=DEFAULT_GRPC_SERVER_ADDRESS,
        help="gRPC server address (default: %(default)s)",
    )
    parser.add_argument(
        "--grpc_server_port",
        type=int,
        default=DEFAULT_GRPC_SERVER_PORT,
        help="gRPC server port (default: %(default)s)",
    )
    parser.add_argument(
        "--rounds",
        type=int,
        default=1,
        help="Number of rounds of federated learning (default: %(default)s)",
    )
    parser.add_argument(
        "--sample_fraction",
        type=float,
        default=0.1,
        help="Fraction of available clients used for fit/evaluate "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--min_sample_size",
        type=int,
        default=1,
        help="Minimum number of clients used for fit/evaluate "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--min_num_clients",
        type=int,
        default=1,
        help="Minimum number of available clients required for sampling "
        "(default: %(default)s)",
    )
    # NOTE(review): args.cid is never read by the server below; the flag is
    # kept so existing launch commands that pass it do not fail to parse.
    parser.add_argument("--cid", type=str, help="Client CID (no default)")
    args = parser.parse_args()

    # Load evaluation data (only the xy_test part is used by the server)
    _, xy_test = fashion_mnist.load_data(partition=0, num_partitions=1)

    # Create client_manager, strategy, and server
    client_manager = fl.SimpleClientManager()
    strategy = fl.strategy.DefaultStrategy(
        fraction_fit=args.sample_fraction,
        min_fit_clients=args.min_sample_size,
        min_available_clients=args.min_num_clients,
        eval_fn=get_eval_fn(xy_test=xy_test),
        on_fit_config_fn=fit_config,
    )
    server = fl.Server(client_manager=client_manager, strategy=strategy)

    # Run server
    fl.app.start_server(
        args.grpc_server_address,
        args.grpc_server_port,
        server,
        config={"num_rounds": args.rounds},
    )