Esempio n. 1
0
def main() -> None:
    """Start Flower baseline.

    Parses the command line, configures the logger and calls ``run`` with the
    selected baseline, setting and adapter.

    Command-line arguments (all required):
        --baseline: baseline to run (``tf_cifar``, ``tf_fashion_mnist`` or
            ``tf_hotkey``).
        --setting: setting to run; must be one of the settings defined by
            the selected baseline.
        --adapter: adapter to use (``docker`` or ``ec2``).

    Exits through ``argparse`` with a usage error when an argument is missing
    or invalid, including a setting the selected baseline does not define.
    """
    # Settings defined by each baseline, keyed by the --baseline value.
    baseline_settings = {
        "tf_cifar": list(tf_cifar_settings.SETTINGS.keys()),
        "tf_fashion_mnist": list(tf_fashion_mnist_settings.SETTINGS.keys()),
        "tf_hotkey": list(tf_hotkey_settings.SETTINGS.keys()),
    }
    # Union of all settings. Sorted because the iteration order of a set of
    # strings varies between interpreter runs (hash randomization), which
    # would make the --help / error output non-deterministic.
    all_settings = sorted(
        {name for names in baseline_settings.values() for name in names}
    )

    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--baseline",
        type=str,
        required=True,
        choices=list(baseline_settings),
        help="Name of baseline name to run.",
    )
    parser.add_argument(
        "--setting",
        type=str,
        required=True,
        choices=all_settings,
        help="Name of setting to run.",
    )
    parser.add_argument(
        "--adapter",
        type=str,
        required=True,
        choices=["docker", "ec2"],
        help="Set adapter to be used.",
    )
    args = parser.parse_args()

    # argparse only checks --setting against the union of all baselines'
    # settings; a setting that belongs to another baseline would otherwise
    # be accepted here.
    valid_settings = baseline_settings[args.baseline]
    if args.setting not in valid_settings:
        parser.error(
            f"argument --setting: invalid choice: '{args.setting}' for "
            f"baseline '{args.baseline}' (choose from "
            + ", ".join(f"'{name}'" for name in valid_settings)
            + ")"
        )

    # Configure logger
    configure(f"flower_{args.baseline}_{args.setting}")

    run(baseline=args.baseline, setting=args.setting, adapter=args.adapter)
Esempio n. 2
0
def main() -> None:
    """Start Flower baseline.

    Parses the command line, configures the logger and calls ``run`` with the
    selected baseline, setting and adapter.

    The choices offered for ``--setting`` (in ``--help`` and in error
    messages) are narrowed to the settings of the baseline passed via
    ``--baseline`` when that baseline is valid; otherwise all known settings
    are offered and argparse reports the missing/invalid baseline.
    """
    # When adding a new setting make sure to modify the load_baseline_setting function
    baseline_settings = {
        "tf_cifar": list(tf_cifar_settings.SETTINGS.keys()),
        "tf_fashion_mnist": list(tf_fashion_mnist_settings.SETTINGS.keys()),
        "tf_hotkey": list(tf_hotkey_settings.SETTINGS.keys()),
    }
    possible_baselines = list(baseline_settings)

    # Show only relevant settings based on baseline as choices for the
    # --setting parameter. A minimal pre-parser reads --baseline in both
    # forms argparse accepts ("--baseline=x" and "--baseline x"); every
    # other argument, including -h, is left for the main parser.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--baseline", type=str)
    pre_args, _ = pre_parser.parse_known_args()
    if pre_args.baseline in baseline_settings:
        possible_settings = baseline_settings[pre_args.baseline]
    else:
        # Baseline missing or invalid: the main parser reports that error.
        # Offer every known setting so --setting is not rejected first
        # against an empty choice list.
        possible_settings = sorted(
            {name for names in baseline_settings.values() for name in names}
        )

    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--baseline",
        type=str,
        required=True,
        choices=possible_baselines,
        help="Name of baseline name to run.",
    )
    parser.add_argument(
        "--setting",
        type=str,
        required=True,
        choices=possible_settings,
        help="Name of setting to run.",
    )
    parser.add_argument(
        "--adapter",
        type=str,
        required=True,
        choices=["docker", "ec2"],
        help="Set adapter to be used.",
    )
    args = parser.parse_args()

    # Configure logger
    configure(f"flower_{args.baseline}_{args.setting}")

    run(baseline=args.baseline, setting=args.setting, adapter=args.adapter)
Esempio n. 3
0
def main() -> None:
    """Load data, create and start CIFAR-10/100 client.

    Parses command-line arguments and resolves this client's setting. It then
    configures the logger and builds a ResNet-50v2 with ``NUM_CLASSES``
    outputs. Next it loads the client's train/test partition, truncated on a
    dry run. Finally it starts the Flower client against
    ``args.server_address``.

    The dataset variant follows ``NUM_CLASSES``: CIFAR-100 when it equals
    100, CIFAR-10 otherwise.
    """
    args = parse_args()

    client_setting = get_client_setting(args.setting, args.cid)

    # Configure logger
    configure(identifier=f"client:{client_setting.cid}", host=args.log_host)
    log(INFO, "Starting client, settings: %s", client_setting)

    # Load model
    model = resnet50v2(input_shape=(32, 32, 3), num_classes=NUM_CLASSES, seed=SEED)

    # Load local data partition. The dataset is derived from NUM_CLASSES so
    # the data's label set always matches the model's output size.
    (xy_train_partitions, xy_test_partitions), _ = tf_cifar_partitioned.load_data(
        iid_fraction=client_setting.iid_fraction,
        num_partitions=client_setting.num_clients,
        cifar100=NUM_CLASSES == 100,
    )
    x_train, y_train = xy_train_partitions[client_setting.partition]
    x_test, y_test = xy_test_partitions[client_setting.partition]
    if client_setting.dry_run:
        # Keep only a few samples for a quick smoke run.
        x_train = x_train[0:100]
        y_train = y_train[0:100]
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Start client
    client = VisionClassificationClient(
        client_setting.cid,
        model,
        (x_train, y_train),
        (x_test, y_test),
        client_setting.delay_factor,
        NUM_CLASSES,
        augment=True,
        augment_horizontal_flip=True,
        augment_offset=2,
    )
    fl.client.start_client(args.server_address, client)
Esempio n. 4
0
def main() -> None:
    """Load data, create and start Fashion-MNIST client.

    Parses command-line arguments and resolves this client's setting. It then
    configures the logger and builds the CNN model. Next it loads the
    client's train/test partition, truncated on a dry run. Finally it starts
    the Flower client against the server address given on the command line.
    """
    cli_args = parse_args()
    setting = get_client_setting(cli_args.setting, cli_args.cid)

    # Logger setup comes before any other work.
    configure(identifier=f"client:{setting.cid}", host=cli_args.log_host)
    log(INFO, "Starting client, settings: %s", setting)

    cnn = orig_cnn(input_shape=(28, 28, 1), seed=SEED)

    # load_data returns ((train_partitions, test_partitions), <unused>).
    partitions, _ = tf_fashion_mnist_partitioned.load_data(
        iid_fraction=setting.iid_fraction,
        num_partitions=setting.num_clients,
    )
    train_partitions, test_partitions = partitions
    x_train, y_train = train_partitions[setting.partition]
    x_test, y_test = test_partitions[setting.partition]

    if setting.dry_run:
        # A handful of samples is enough for a smoke run.
        x_train, y_train = x_train[:100], y_train[:100]
        x_test, y_test = x_test[:50], y_test[:50]

    client = VisionClassificationClient(
        setting.cid,
        cnn,
        (x_train, y_train),
        (x_test, y_test),
        setting.delay_factor,
        10,  # number of classes
        augment=True,
        augment_horizontal_flip=False,
        augment_offset=1,
    )
    fl.client.start_client(cli_args.server_address, client)
Esempio n. 5
0
def main() -> None:
    """Load data, create and start client.

    Builds the keyword CNN and loads this client's hotkey train/test
    partition, truncated on a dry run. It then starts a Flower client with
    inputs scaled by a normalization factor of 100.
    """
    cli_args = parse_args()
    setting = get_client_setting(cli_args.setting, cli_args.cid)

    # Logger setup comes before any other work.
    configure(identifier=f"client:{setting.cid}", host=cli_args.log_host)

    cnn = keyword_cnn(input_shape=(80, 40, 1), seed=SEED)

    # load_data returns ((train_partitions, test_partitions), <unused>).
    partitions, _ = tf_hotkey_partitioned.load_data(
        iid_fraction=setting.iid_fraction,
        num_partitions=setting.num_clients,
    )
    train_partitions, test_partitions = partitions
    x_train, y_train = train_partitions[setting.partition]
    x_test, y_test = test_partitions[setting.partition]

    if setting.dry_run:
        # A handful of samples is enough for a smoke run.
        x_train, y_train = x_train[:100], y_train[:100]
        x_test, y_test = x_test[:50], y_test[:50]

    client = VisionClassificationClient(
        setting.cid,
        cnn,
        (x_train, y_train),
        (x_test, y_test),
        setting.delay_factor,
        10,  # number of classes
        normalization_factor=100.0,
    )
    fl.client.start_client(cli_args.server_address, client)
Esempio n. 6
0
def main() -> None:
    """Start server and train a number of rounds.

    Parses command-line arguments and configures the logger. It loads the
    centralized evaluation data, truncated on a dry run, and builds the
    keyword CNN used by the evaluation function. It then creates the strategy
    named by ``server_setting.strategy`` and starts the Flower server on
    ``DEFAULT_SERVER_ADDRESS`` for ``server_setting.rounds`` rounds.

    Raises:
        ValueError: if ``server_setting.strategy`` is not one of ``fedavg``,
            ``fast-and-slow``, ``fedfs-v0`` or ``qffedavg``, or if a
            timeout-based strategy (``fast-and-slow``, ``fedfs-v0``) is
            selected without ``training_round_timeout``.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_hotkey_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1
    )
    if server_setting.dry_run:
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = keyword_cnn(input_shape=(80, 40, 1), seed=SEED)

    # Strategy
    eval_fn = get_eval_fn(model=model, num_classes=10, xy_test=(x_test, y_test))
    on_fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    if server_setting.strategy == "fedavg":
        strategy = fl.server.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    elif server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy")
        strategy = fl.server.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.9,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            t_fast=math.ceil(0.5 * server_setting.training_round_timeout),
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "fedfs-v0":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fedfs-v0` strategy")
        # Short timeout defaults to half of the regular round timeout.
        t_fast = (math.ceil(0.5 * server_setting.training_round_timeout)
                  if server_setting.training_round_timeout_short is None else
                  server_setting.training_round_timeout_short)
        strategy = fl.server.strategy.FedFSv0(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )
    elif server_setting.strategy == "qffedavg":
        strategy = fl.server.strategy.QFedAvg(
            q_param=0.2,
            qffl_learning_rate=0.1,
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )
    else:
        # Fail fast with a clear message instead of reaching start_server
        # with `strategy` unbound (UnboundLocalError).
        raise ValueError(
            f"Unknown strategy `{server_setting.strategy}`, expected one of "
            "`fedavg`, `fast-and-slow`, `fedfs-v0`, `qffedavg`")

    # Run server
    fl.server.start_server(
        DEFAULT_SERVER_ADDRESS,
        config={"num_rounds": server_setting.rounds},
        strategy=strategy,
    )
Esempio n. 7
0
def main() -> None:
    """Start server and train a number of rounds.

    Parses command-line arguments and configures the logger. It loads the
    centralized CIFAR evaluation data (CIFAR-100 when ``NUM_CLASSES`` is 100),
    truncated on a dry run, and builds the ResNet-50v2 used by the evaluation
    function. It then creates the strategy named by
    ``server_setting.strategy`` and starts the Flower server on
    ``DEFAULT_SERVER_ADDRESS`` for ``server_setting.rounds`` rounds.

    Raises:
        ValueError: if ``server_setting.strategy`` is neither ``fedavg`` nor
            ``fast-and-slow``, or if ``fast-and-slow`` is selected without
            ``training_round_timeout``.
    """
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_cifar_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1, cifar100=NUM_CLASSES == 100
    )
    if server_setting.dry_run:
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = resnet50v2(input_shape=(32, 32, 3), num_classes=NUM_CLASSES, seed=SEED)

    # Create client_manager
    client_manager = fl.server.SimpleClientManager()

    # Strategy
    eval_fn = get_eval_fn(
        model=model, num_classes=NUM_CLASSES, xy_test=(x_test, y_test)
    )
    fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    if server_setting.strategy == "fedavg":
        strategy = fl.server.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=fit_config_fn,
        )
    elif server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        strategy = fl.server.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.8,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            t_fast=math.ceil(0.5 * server_setting.training_round_timeout),
            t_slow=server_setting.training_round_timeout,
        )
    else:
        # Fail fast with a clear message instead of constructing the Server
        # with `strategy` unbound (UnboundLocalError).
        raise ValueError(
            f"Unknown strategy `{server_setting.strategy}`, expected one of "
            "`fedavg`, `fast-and-slow`"
        )

    # Run server
    server = fl.server.Server(client_manager=client_manager, strategy=strategy)
    fl.server.start_server(
        DEFAULT_SERVER_ADDRESS, server, config={"num_rounds": server_setting.rounds},
    )