Example #1
def get_client_setting(setting: str, cid: str) -> ClientSetting:
    """Return client setting based on setting name and cid."""
    for client_setting in get_setting(setting).clients:
        if client_setting.cid == cid:
            return client_setting

    raise ClientSettingNotFound()
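
A minimal usage sketch of the function above. The setting name "minimal" and the cid "0" are illustrative, not values defined in this example:

# Illustrative usage; "minimal" and "0" are assumed names.
try:
    client_setting = get_client_setting(setting="minimal", cid="0")
    print(client_setting.instance_name)
except ClientSettingNotFound:
    print("No client with this cid in the selected setting.")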
Example #2
def load_baseline_setting(baseline: str, setting: str) -> Baseline:
    """Return appropriate baseline setting."""
    if baseline == "tf_cifar":
        return tf_cifar_settings.get_setting(setting)
    if baseline == "tf_fashion_mnist":
        return tf_fashion_mnist_settings.get_setting(setting)
    if baseline == "tf_hotkey":
        return tf_hotkey_settings.get_setting(setting)

    raise Exception("Setting not found.")
Example #3
def main() -> None:
    """Start server and train a number of rounds."""
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_fashion_mnist_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1
    )
    if server_setting.dry_run:
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = orig_cnn(input_shape=(28, 28, 1), seed=SEED)

    # Create client_manager
    client_manager = fl.SimpleClientManager()

    # Strategy
    eval_fn = get_eval_fn(model=model, num_classes=10, xy_test=(x_test, y_test))
    on_fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    if server_setting.strategy == "fedavg":
        strategy = fl.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )

    if server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        t_fast = (
            math.ceil(0.5 * server_setting.training_round_timeout)
            if server_setting.training_round_timeout_short is None
            else server_setting.training_round_timeout_short
        )
        strategy = fl.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.8,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )

    if server_setting.strategy == "fedfs-v0":
        if server_setting.training_round_timeout is None:
            raise ValueError("No `training_round_timeout` set for `fedfs-v0` strategy")
        t_fast = (
            math.ceil(0.5 * server_setting.training_round_timeout)
            if server_setting.training_round_timeout_short is None
            else server_setting.training_round_timeout_short
        )
        strategy = fl.strategy.FedFSv0(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )

    if server_setting.strategy == "fedfs-v1":
        if server_setting.training_round_timeout is None:
            raise ValueError("No `training_round_timeout` set for `fedfs-v1` strategy")
        strategy = fl.strategy.FedFSv1(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            dynamic_timeout_percentile=0.8,
            r_fast=1,
            r_slow=1,
            t_max=server_setting.training_round_timeout,
            use_past_contributions=True,
        )

    if server_setting.strategy == "qffedavg":
        strategy = fl.strategy.QffedAvg(
            q_param=0.2,
            qffl_learning_rate=0.1,
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )

    # Run server
    log(INFO, "Instantiating server, strategy: %s", str(strategy))
    server = fl.Server(client_manager=client_manager, strategy=strategy)
    fl.app.server.start_server(
        DEFAULT_SERVER_ADDRESS, server, config={"num_rounds": server_setting.rounds},
    )
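
The strategies above share two hooks: eval_fn, which evaluates the global model centrally, and on_fit_config_fn, which maps the current round number to the configuration dictionary sent to the clients selected for training. The following is a sketch of what get_on_fit_config_fn might look like; it is an assumption about the helper's implementation (including the config key names), not the actual baseline code:

from typing import Callable, Dict, Optional


def get_on_fit_config_fn(
    lr_initial: float, timeout: Optional[int], partial_updates: bool
) -> Callable[[int], Dict[str, str]]:
    """Return a function that builds the per-round fit configuration (sketch)."""

    def fit_config(rnd: int) -> Dict[str, str]:
        # Key names below are illustrative, not the baseline's actual config keys.
        config = {
            "epoch_global": str(rnd),
            "lr_initial": str(lr_initial),
            "partial_updates": "1" if partial_updates else "0",
        }
        if timeout is not None:
            config["timeout"] = str(timeout)
        return config

    return fit_config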
Example #4
def run(baseline: str, setting: str, adapter: str) -> None:
    """Run baseline."""
    print(f"Starting baseline with {setting} settings.")

    wheel_remote_path = (f"/root/{WHEEL_FILENAME}" if adapter == "docker" else
                         f"/home/ubuntu/{WHEEL_FILENAME}")

    if baseline == "tf_cifar":
        settings = tf_cifar_settings.get_setting(setting)
    elif baseline == "tf_fashion_mnist":
        settings = tf_fashion_mnist_settings.get_setting(setting)
    elif baseline == "tf_hotkey":
        settings = tf_hotkey_settings.get_setting(setting)
    else:
        raise Exception("Setting not found.")

    # Get instances and add a logserver to the list
    instances = settings.instances
    instances.append(
        Instance(name="logserver", group="logserver", num_cpu=2, num_ram=2))

    # Configure cluster
    log(INFO, "(1/9) Configure cluster.")
    cluster = configure_cluster(adapter, instances, baseline, setting)

    # Start the cluster; this takes some time
    log(INFO, "(2/9) Start cluster.")
    cluster.start()

    # Upload wheel to all instances
    log(INFO, "(3/9) Upload wheel to all instances.")
    cluster.upload_all(WHEEL_LOCAL_PATH, wheel_remote_path)

    # Install the wheel on all instances
    log(INFO, "(4/9) Install wheel on all instances.")
    cluster.exec_all(command.install_wheel(wheel_remote_path))

    # Download dataset on server and clients
    log(INFO, "(5/9) Download dataset on server and clients.")
    cluster.exec_all(command.download_dataset(baseline=baseline),
                     groups=["server", "clients"])

    # Start logserver
    log(INFO, "(6/9) Start logserver.")
    logserver = cluster.get_instance("logserver")
    cluster.exec(
        logserver.name,
        command.start_logserver(
            logserver_s3_bucket=CONFIG.get("aws", "logserver_s3_bucket"),
            logserver_s3_key=f"{baseline}_{setting}_{now()}.log",
        ),
    )

    # Start Flower server on the server instance
    log(INFO, "(7/9) Start server.")
    cluster.exec(
        "server",
        command.start_server(
            log_host=f"{logserver.private_ip}:8081",
            baseline=baseline,
            setting=setting,
        ),
    )

    # Start Flower clients
    log(INFO, "(8/9) Start clients.")
    server = cluster.get_instance("server")

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        # Start all clients in parallel and wait until every start command has returned
        concurrent.futures.wait([
            executor.submit(
                cluster.exec,
                client_setting.instance_name,
                command.start_client(
                    log_host=f"{logserver.private_ip}:8081",
                    server_address=f"{server.private_ip}:8080",
                    baseline=baseline,
                    setting=setting,
                    cid=client_setting.cid,
                ),
            ) for client_setting in settings.clients
        ])

    # Shut down server and client instances after 10 minutes if no Flower
    # process is running on them
    log(INFO, "(9/9) Start shutdown watcher script.")
    cluster.exec_all(command.watch_and_shutdown("flower", adapter))

    # Tell the user how to tail the central logfile
    private_key = (DOCKER_PRIVATE_KEY if adapter == "docker" else
                   path.expanduser(CONFIG.get("ssh", "private_key")))

    log(
        INFO,
        "If you would like to tail the central logfile run:\n\n\t%s\n",
        command.tail_logfile(adapter, private_key, logserver),
    )
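
A sketch of how run might be invoked; the setting name is illustrative, and the actual entry point presumably parses these values from the command line:

# Illustrative invocation; "minimal" is an assumed setting name.
if __name__ == "__main__":
    run(baseline="tf_fashion_mnist", setting="minimal", adapter="docker")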