Example #1
    def fit_round(self, rnd: int) -> Optional[Weights]:
        """Perform a single round of federated averaging."""
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_fit(
            rnd=rnd, weights=self.weights, client_manager=self._client_manager
        )
        log(
            DEBUG,
            "fit_round: strategy sampled %s clients (out of %s)",
            len(client_instructions),
            self._client_manager.num_available(),
        )
        if not client_instructions:
            log(INFO, "fit_round: no clients sampled, cancel fit")
            return None

        # Collect training results from all clients participating in this round
        results, failures = fit_clients(client_instructions)
        log(
            DEBUG,
            "fit_round received %s results and %s failures",
            len(results),
            len(failures),
        )

        # Aggregate training results
        return self.strategy.aggregate_fit(rnd, results, failures)
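
The heavy lifting in the last line above happens inside the strategy. For a FedAvg-style strategy, `aggregate_fit` essentially computes the example-weighted average of the returned client weights. A minimal sketch of that technique (illustrative, not the exact Flower internals):

from functools import reduce
from typing import List, Tuple

import numpy as np

# A model's weights are a list of NumPy arrays, one per layer
Weights = List[np.ndarray]


def aggregate(results: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighted by each client's number of examples."""
    num_examples_total = sum(num_examples for _, num_examples in results)
    # Scale each client's layers by its number of training examples
    weighted_weights = [
        [layer * num_examples for layer in weights]
        for weights, num_examples in results
    ]
    # Sum the scaled layers across clients and normalize by the total
    return [
        reduce(np.add, layer_updates) / num_examples_total
        for layer_updates in zip(*weighted_weights)
    ]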
Example #2
    def start(self) -> None:
        """Start the instance."""
        instance_groups = group_instances_by_specs(self.instances)

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [
                executor.submit(create_instances, self.adapter, instance_group,
                                self.timeout)
                for instance_group in instance_groups
            ]
            concurrent.futures.wait(futures)
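            # Note: wait() above blocks until all futures are done, so the
            # cancel() calls in the error path below are effectively no-ops;
            # the actual cleanup happens in terminate().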

            try:
                for future in futures:
                    future.result()
            # pylint: disable=broad-except
            except Exception as exc:
                log(
                    ERROR,
                    "Failed to start the cluster completely. Shutting down...",
                )
                log(ERROR, exc)

                for future in futures:
                    future.cancel()

                self.terminate()
                raise StartFailed()

        for ins in self.instances:
            log(DEBUG, ins)
Example #3
    def on_configure_fit(
            self, rnd: int, weights: Weights,
            client_manager: ClientManager) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""

        # Block until `min_num_clients` are available
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available())
        success = client_manager.wait_for(num_clients=min_num_clients,
                                          timeout=WAIT_TIMEOUT)
        if not success:
            # Do not continue if not enough clients are available
            log(
                INFO,
                "FedFS: not enough clients available after timeout %s",
                WAIT_TIMEOUT,
            )
            return []

        # Sample clients
        clients = self._contribution_based_sampling(
            sample_size=sample_size, client_manager=client_manager)

        # Prepare parameters and config
        parameters = weights_to_parameters(weights)
        config = {}
        if self.on_fit_config_fn is not None:
            # Use custom fit config function if provided
            config = self.on_fit_config_fn(rnd)

        # Set timeout for this round
        use_fast_timeout = is_fast_round(rnd - 1, self.r_fast, self.r_slow)
        config["timeout"] = str(
            self.t_fast if use_fast_timeout else self.t_slow)

        # Fit instructions
        fit_ins = FitIns(parameters, config)

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]
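
The `is_fast_round` helper used above is not part of this snippet. A plausible implementation, inferred from the call sites (rounds cycle through `r_fast` fast rounds followed by `r_slow` slow rounds), could look like this:

def is_fast_round(rnd: int, r_fast: int, r_slow: int) -> bool:
    """Return True if `rnd` falls into the fast part of the cycle.

    Sketch inferred from the call site, not necessarily the exact helper.
    """
    remainder = rnd % (r_fast + r_slow)
    return remainder < r_fast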
Example #4
    def evaluate(
        self, rnd: int
    ) -> Optional[Tuple[Optional[float], EvaluateResultsAndFailures]]:
        """Validate current global model on a number of clients."""
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_evaluate(
            rnd=rnd, weights=self.weights, client_manager=self._client_manager
        )
        if not client_instructions:
            log(INFO, "evaluate: no clients sampled, cancel federated evaluation")
            return None
        log(
            DEBUG,
            "evaluate: strategy sampled %s clients",
            len(client_instructions),
        )

        # Evaluate current global weights on those clients
        results_and_failures = evaluate_clients(client_instructions)
        results, failures = results_and_failures
        log(
            DEBUG,
            "evaluate received %s results and %s failures",
            len(results),
            len(failures),
        )
        # Aggregate the evaluation results
        loss_aggregated = self.strategy.aggregate_evaluate(rnd, results, failures)
        return loss_aggregated, results_and_failures
Example #5
    def upload_all(self, local_path: str,
                   remote_path: str) -> Dict[str, SFTPAttributes]:
        """Upload file to all instances."""
        results: Dict[str, SFTPAttributes] = {}

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            # Start the upload operations and mark each future with its
            # instance name
            future_to_result = {
                executor.submit(self.upload, instance_name, local_path,
                                remote_path): instance_name
                for instance_name in self.get_instance_names()
            }

            for future in concurrent.futures.as_completed(future_to_result):
                instance_name = future_to_result[future]
                try:
                    results[instance_name] = future.result()
                # pylint: disable=broad-except
                except Exception as exc:
                    log(ERROR, (instance_name, exc))

        return results
Example #6
File: client.py Project: luan-gu/flower
    def evaluate(self, ins: fl.common.EvaluateIns) -> fl.common.EvaluateRes:
        weights = fl.common.parameters_to_weights(ins.parameters)
        config = ins.config
        log(
            DEBUG,
            "evaluate on %s (examples: %s), config %s",
            self.cid,
            self.num_examples_test,
            config,
        )

        # Use provided weights to update the local model
        self.model.set_weights(weights)

        # Evaluate the updated model on the local dataset
        loss, acc = keras_evaluate(self.model,
                                   self.ds_test,
                                   batch_size=self.num_examples_test)

        # Return the number of evaluation examples and the evaluation results
        # (loss and accuracy)
        return fl.common.EvaluateRes(num_examples=self.num_examples_test,
                                     loss=loss,
                                     accuracy=acc)
Example #7
def start_client(
    server_address: str,
    client: Client,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
) -> None:
    """Start a Flower Client which connects to a gRPC server.

    Arguments:
        server_address: str. The IPv6 address of the server. If the Flower
            server runs on the same machine on port 8080, then `server_address`
            would be `"[::]:8080"`.
        client: flwr.client.Client. An implementation of the abstract base
            class `flwr.client.Client`.
        grpc_max_message_length: int (default: 536_870_912, this equals 512MB).
            The maximum length of gRPC messages that can be exchanged with the
            Flower server. The default should be sufficient for most models.
            Users who train very large models might need to increase this
            value. Note that the Flower server needs to be started with the
            same value (see `flwr.server.start_server`), otherwise it will not
            know about the increased limit and block larger messages.

    Returns:
        None.
    """
    while True:
        sleep_duration: int = 0
        with insecure_grpc_connection(
            server_address, max_message_length=grpc_max_message_length
        ) as conn:
            receive, send = conn
            log(INFO, "Opened (insecure) gRPC connection")

            while True:
                server_message = receive()
                client_message, sleep_duration, keep_going = handle(
                    client, server_message
                )
                send(client_message)
                if not keep_going:
                    break
        if sleep_duration == 0:
            log(INFO, "Disconnect and shut down")
            break
        # Sleep and reconnect afterwards
        log(
            INFO,
            "Disconnect, then re-establish connection after %s second(s)",
            sleep_duration,
        )
        time.sleep(sleep_duration)
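
For reference, a minimal usage sketch; `FlowerClient` is a placeholder for your own subclass of `flwr.client.Client` implementing `get_parameters`, `fit`, and `evaluate`:

import flwr as fl

# FlowerClient is a hypothetical `flwr.client.Client` implementation
client = FlowerClient()
fl.client.start_client(server_address="[::]:8080", client=client)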
Example #8
    def _get_initial_parameters(self) -> Parameters:
        """Get initial parameters from one of the available clients."""

        # Server-side parameter initialization
        parameters: Optional[Parameters] = self.strategy.initialize_parameters(
            client_manager=self._client_manager)
        if parameters is not None:
            log(INFO, "Using initial parameters provided by strategy")
            return parameters

        # Get initial parameters from one of the clients
        log(INFO, "Requesting initial parameters from one random client")
        random_client = self._client_manager.sample(1)[0]
        parameters_res = random_client.get_parameters()
        log(INFO, "Received initial parameters from one random client")
        return parameters_res.parameters
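
To take the first branch (server-side initialization), a strategy only needs to return non-None parameters from `initialize_parameters`. A minimal sketch, assuming a FedAvg-style base class and a hypothetical `build_model` helper that creates a Keras model:

class FedAvgWithInit(FedAvg):
    def initialize_parameters(
            self, client_manager: ClientManager) -> Optional[Parameters]:
        model = build_model()  # hypothetical helper returning a Keras model
        return weights_to_parameters(model.get_weights())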
Example #9
def run(baseline: str, setting: str, adapter: str) -> None:
    """Run baseline."""
    print(f"Starting baseline with {setting} settings.")

    wheel_remote_path = (f"/root/{WHEEL_FILENAME}" if adapter == "docker" else
                         f"/home/ubuntu/{WHEEL_FILENAME}")

    settings = load_baseline_setting(baseline, setting)

    # Get instances and add a logserver to the list
    instances = settings.instances
    instances.append(
        Instance(name="logserver", group="logserver", num_cpu=2, num_ram=2))

    # Configure cluster
    log(INFO, "(1/9) Configure cluster.")
    cluster = configure_cluster(adapter, instances, baseline, setting)

    # Start the cluster; this takes some time
    log(INFO, "(2/9) Start cluster.")
    cluster.start()

    # Upload wheel to all instances
    log(INFO, "(3/9) Upload wheel to all instances.")
    cluster.upload_all(WHEEL_LOCAL_PATH, wheel_remote_path)

    # Install the wheel on all instances
    log(INFO, "(4/9) Install wheel on all instances.")
    cluster.exec_all(command.install_wheel(wheel_remote_path))
    extras = (["examples-tensorflow"]
              if "tf_" in baseline else ["examples-pytorch"])
    cluster.exec_all(
        command.install_wheel(wheel_remote_path=wheel_remote_path,
                              wheel_extras=extras))

    # Download datasets in server and clients
    log(INFO, "(5/9) Download dataset on server and clients.")
    cluster.exec_all(command.download_dataset(baseline=baseline),
                     groups=["server", "clients"])

    # Start logserver
    log(INFO, "(6/9) Start logserver.")
    logserver = cluster.get_instance("logserver")
    cluster.exec(
        logserver.name,
        command.start_logserver(
            logserver_s3_bucket=CONFIG.get("aws", "logserver_s3_bucket"),
            logserver_s3_key=f"{baseline}_{setting}_{now()}.log",
        ),
    )

    # Start Flower server on Flower server instances
    log(INFO, "(7/9) Start server.")
    cluster.exec(
        "server",
        command.start_server(
            log_host=f"{logserver.private_ip}:8081",
            baseline=baseline,
            setting=setting,
        ),
    )

    # Start Flower clients
    log(INFO, "(8/9) Start clients.")
    server = cluster.get_instance("server")

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        # Start all clients concurrently and wait for them to complete
        concurrent.futures.wait([
            executor.submit(
                cluster.exec,
                client_setting.instance_name,
                command.start_client(
                    log_host=f"{logserver.private_ip}:8081",
                    server_address=f"{server.private_ip}:8080",
                    baseline=baseline,
                    setting=setting,
                    cid=client_setting.cid,
                ),
            ) for client_setting in settings.clients
        ])

    # Shut down the server and client instances after 10 minutes if no Flower
    # process is running on them
    log(INFO, "(9/9) Start shutdown watcher script.")
    cluster.exec_all(command.watch_and_shutdown("flwr", adapter))

    # Show the user how to tail the logfile
    private_key = (DOCKER_PRIVATE_KEY if adapter == "docker" else
                   path.expanduser(CONFIG.get("ssh", "private_key")))

    log(
        INFO,
        "If you would like to tail the central logfile run:\n\n\t%s\n",
        command.tail_logfile(adapter, private_key, logserver),
    )
Example #10
def start_server(  # pylint: disable=too-many-arguments
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    force_final_distributed_eval: bool = False,
    certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
) -> History:
    """Start a Flower server using the gRPC transport layer.

    Arguments
    ---------
        server_address: Optional[str] (default: `"[::]:8080"`). The IPv6
            address of the server.
        server: Optional[flwr.server.Server] (default: None). An implementation
            of the abstract base class `flwr.server.Server`. If no instance is
            provided, then `start_server` will create one.
        config: Optional[Dict[str, int]] (default: None). The only currently
            supported value is `num_rounds`, so a full configuration object
            instructing the server to perform three rounds of federated
            learning looks like the following: `{"num_rounds": 3}`.
        strategy: Optional[flwr.server.Strategy] (default: None). An
            implementation of the abstract base class `flwr.server.Strategy`.
            If no strategy is provided, then `start_server` will use
            `flwr.server.strategy.FedAvg`.
        grpc_max_message_length: int (default: 536_870_912, this equals 512MB).
            The maximum length of gRPC messages that can be exchanged with the
            Flower clients. The default should be sufficient for most models.
            Users who train very large models might need to increase this
            value. Note that the Flower clients need to be started with the
            same value (see `flwr.client.start_client`), otherwise clients will
            not know about the increased limit and block larger messages.
        force_final_distributed_eval: bool (default: False).
            Forces a distributed evaluation to occur after the last training
            epoch when enabled.
        certificates : Tuple[bytes, bytes, bytes] (default: None)
            Tuple containing the root certificate, server certificate, and
            private key needed to start a secure SSL-enabled server. The tuple
            is expected to contain three `bytes` elements in the following
            order:

                * CA certificate.
                * server certificate.
                * server private key.

    Returns
    -------
        hist: flwr.server.history.History. Object containing metrics from training.

    Examples
    --------
    Starting an insecure server:

    >>> start_server()

    Starting an SSL-enabled server:

    >>> start_server(
    >>>     certificates=(
    >>>         Path("/crts/root.pem").read_bytes(),
    >>>         Path("/crts/localhost.crt").read_bytes(),
    >>>         Path("/crts/localhost.key").read_bytes()
    >>>     )
    >>> )
    """
    initialized_server, initialized_config = _init_defaults(
        server, config, strategy)

    # Start gRPC server
    grpc_server = start_grpc_server(
        client_manager=initialized_server.client_manager(),
        server_address=server_address,
        max_message_length=grpc_max_message_length,
        certificates=certificates,
    )
    num_rounds = initialized_config["num_rounds"]
    ssl_status = "enabled" if certificates is not None else "disabled"
    msg = f"Flower server running ({num_rounds} rounds)\nSSL is {ssl_status}"
    log(INFO, msg)

    hist = _fl(
        server=initialized_server,
        config=initialized_config,
        force_final_distributed_eval=force_final_distributed_eval,
    )

    # Stop the gRPC server
    grpc_server.stop(grace=1)

    return hist
Example #11
File: common.py Project: zliel/flower
def custom_fit(
    model: tf.keras.Model,
    dataset: tf.data.Dataset,
    num_epochs: int,
    batch_size: int,
    callbacks: List[tf.keras.callbacks.Callback],
    delay_factor: float = 0.0,
    timeout: Optional[int] = None,
) -> Tuple[bool, float, int]:
    """Train the model using a custom training loop."""
    ds_train = dataset.batch(batch_size=batch_size, drop_remainder=False)

    # Keep results for plotting
    train_loss_results = []
    train_accuracy_results = []

    # Optimizer
    optimizer = tf.keras.optimizers.Adam()

    fit_begin = timeit.default_timer()
    num_examples = 0
    for epoch in range(num_epochs):
        log(INFO, "Starting epoch %s", epoch)

        epoch_loss_avg = tf.keras.metrics.Mean()
        epoch_accuracy = tf.keras.metrics.CategoricalAccuracy()

        # Single loop over the dataset
        batch_begin = timeit.default_timer()
        num_examples_batch = 0
        for batch, (x, y) in enumerate(ds_train):
            num_examples_batch += len(x)

            # Optimize the model
            loss_value, grads = grad(model, x, y)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            # Track progress
            epoch_loss_avg.update_state(
                loss_value)  # Add the current batch loss
            epoch_accuracy.update_state(y, model(x, training=True))

            # Track the number of examples used for training
            num_examples += x.shape[0]

            # Delay
            batch_duration = timeit.default_timer() - batch_begin
            if delay_factor > 0.0:
                time.sleep(batch_duration * delay_factor)

            # Progress log
            if batch % 100 == 0:
                log(
                    INFO,
                    "Batch %s: loss %s (%s examples processed, batch duration: %s)",
                    batch,
                    loss_value,
                    num_examples_batch,
                    batch_duration,
                )

            # Timeout
            if timeout is not None:
                fit_duration = timeit.default_timer() - fit_begin
                if fit_duration > timeout:
                    log(INFO, "client timeout")
                    return (False, fit_duration, num_examples)
            batch_begin = timeit.default_timer()

        # End of epoch: record and log this epoch's loss and accuracy
        train_loss_results.append(epoch_loss_avg.result())
        train_accuracy_results.append(epoch_accuracy.result())
        log(
            INFO,
            "Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(
                epoch, epoch_loss_avg.result(), epoch_accuracy.result()),
        )

    fit_duration = timeit.default_timer() - fit_begin
    return True, fit_duration, num_examples
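
The `grad` helper called in the inner loop is not shown in this snippet. A standard TF2 implementation consistent with the call site, assuming a categorical cross-entropy loss (the project's actual helper may differ):

def grad(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor):
    """Compute the batch loss and its gradients w.r.t. the trainable
    variables (sketch; the choice of loss function is an assumption)."""
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = tf.reduce_mean(
            tf.keras.losses.categorical_crossentropy(y, logits))
    return loss_value, tape.gradient(loss_value, model.trainable_variables)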
Example #12
    def fit(self, num_rounds: int) -> History:
        """Run federated averaging for a number of rounds."""
        history = History()
        # Initialize weights by asking one client to return theirs
        self.weights = self._get_initial_weights()
        res = self.strategy.evaluate(weights=self.weights)
        if res is not None:
            log(
                INFO,
                "initial weights (loss/accuracy): %s, %s",
                res[0],
                res[1],
            )
            history.add_loss_centralized(rnd=0, loss=res[0])
            history.add_accuracy_centralized(rnd=0, acc=res[1])

        # Run federated learning for num_rounds
        log(INFO, "[TIME] FL starting")
        start_time = timeit.default_timer()

        for current_round in range(1, num_rounds + 1):
            # Train model and replace previous global model
            weights_prime = self.fit_round(rnd=current_round)
            if weights_prime is not None:
                self.weights = weights_prime

            # Evaluate model using strategy implementation
            res_cen = self.strategy.evaluate(weights=self.weights)
            if res_cen is not None:
                loss_cen, acc_cen = res_cen
                log(
                    INFO,
                    "fit progress: (%s, %s, %s, %s)",
                    current_round,
                    loss_cen,
                    acc_cen,
                    timeit.default_timer() - start_time,
                )
                history.add_loss_centralized(rnd=current_round, loss=loss_cen)
                history.add_accuracy_centralized(rnd=current_round,
                                                 acc=acc_cen)

            # Evaluate model on a sample of available clients
            res_fed = self.evaluate(rnd=current_round)
            if res_fed is not None and res_fed[0] is not None:
                loss_fed, _ = res_fed
                history.add_loss_distributed(rnd=current_round,
                                             loss=cast(float, loss_fed))

            # Conclude round
            loss = res_cen[0] if res_cen is not None else None
            acc = res_cen[1] if res_cen is not None else None
            should_continue = self.strategy.on_conclude_round(
                current_round, loss, acc)
            if not should_continue:
                break

        # Send shutdown signal to all clients
        all_clients = self._client_manager.all()
        _ = shutdown(clients=list(all_clients.values()))

        # Bookkeeping
        end_time = timeit.default_timer()
        elapsed = end_time - start_time
        log(INFO, "[TIME] FL finished in %s", elapsed)
        return history
Example #13
File: app.py Project: vballoli/flower
def start_server(
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
) -> None:
    """Start a Flower server using the gRPC transport layer."""

    # Create server instance if none was given
    if server is None:
        client_manager = SimpleClientManager()
        if strategy is None:
            strategy = FedAvg()
        server = Server(client_manager=client_manager, strategy=strategy)

    # Set default config values
    if config is None:
        config = {}
    if "num_rounds" not in config:
        config["num_rounds"] = 1

    # Start gRPC server
    grpc_server = start_insecure_grpc_server(
        client_manager=server.client_manager(), server_address=server_address)
    log(INFO, "Flower server running (insecure, %s rounds)",
        config["num_rounds"])

    # Fit model
    hist = server.fit(num_rounds=config["num_rounds"])
    log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
    log(INFO, "app_fit: accuracies_distributed %s",
        str(hist.accuracies_distributed))
    log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
    log(INFO, "app_fit: accuracies_centralized %s",
        str(hist.accuracies_centralized))

    # Temporary workaround to force distributed evaluation
    server.strategy.eval_fn = None  # type: ignore

    # Evaluate the final trained model
    res = server.evaluate(rnd=-1)
    if res is not None:
        loss, (results, failures) = res
        log(INFO, "app_evaluate: federated loss: %s", str(loss))
        log(
            INFO,
            "app_evaluate: results %s",
            str([(res[0].cid, res[1]) for res in results]),
        )
        log(INFO, "app_evaluate: failures %s", str(failures))
    else:
        log(INFO, "app_evaluate: no evaluation result")

    # Stop the gRPC server
    grpc_server.stop(1)
Example #14
    def fit(self, num_rounds: int) -> History:
        """Run federated averaging for a number of rounds."""
        history = History()
        # Initialize weights by asking one client to return theirs
        log(INFO, "Getting initial parameters")
        self.weights = self._get_initial_weights()
        log(INFO, "Evaluating initial parameters")
        res = self.strategy.evaluate(weights=self.weights)
        if res is not None:
            log(
                INFO,
                "initial weights (loss/accuracy): %s, %s",
                res[0],
                res[1],
            )
            history.add_loss_centralized(rnd=0, loss=res[0])
            history.add_accuracy_centralized(rnd=0, acc=res[1])

        # Run federated learning for num_rounds
        log(INFO, "[TIME] FL starting")
        start_time = timeit.default_timer()

        for current_round in range(1, num_rounds + 1):
            # Train model and replace previous global model
            weights_prime = self.fit_round(rnd=current_round)
            if weights_prime is not None:
                self.weights = weights_prime

            # Evaluate model using strategy implementation
            res_cen = self.strategy.evaluate(weights=self.weights)
            if res_cen is not None:
                loss_cen, acc_cen = res_cen
                log(
                    INFO,
                    "fit progress: (%s, %s, %s, %s)",
                    current_round,
                    loss_cen,
                    acc_cen,
                    timeit.default_timer() - start_time,
                )
                history.add_loss_centralized(rnd=current_round, loss=loss_cen)
                history.add_accuracy_centralized(rnd=current_round, acc=acc_cen)

            # Evaluate model on a sample of available clients
            res_fed = self.evaluate(rnd=current_round)
            if res_fed is not None and res_fed[0] is not None:
                loss_fed, _ = res_fed
                history.add_loss_distributed(
                    rnd=current_round, loss=cast(float, loss_fed)
                )

        # Bookkeeping
        end_time = timeit.default_timer()
        elapsed = end_time - start_time
        log(INFO, "[TIME] FL finished in %s", elapsed)
        return history
Example #15
def on_channel_state_change(channel_connectivity: str) -> None:
    """Log channel connectivity."""
    log(DEBUG, channel_connectivity)
Example #16
    def configure_fit(
            self, rnd: int, weights: Weights,
            client_manager: ClientManager) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""

        # Block until `min_num_clients` are available
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available())
        success = client_manager.wait_for(num_clients=min_num_clients,
                                          timeout=WAIT_TIMEOUT)
        if not success:
            # Do not continue if not enough clients are available
            log(
                INFO,
                "FedFS: not enough clients available after timeout %s",
                WAIT_TIMEOUT,
            )
            return []

        # Sample clients
        msg = "FedFS round %s, sample %s clients (based on all previous contributions)"
        if self.alternating_timeout:
            log(
                DEBUG,
                msg,
                str(rnd),
                str(sample_size),
            )
            clients = self._contribution_based_sampling(
                sample_size=sample_size, client_manager=client_manager)
        elif self.importance_sampling:
            if rnd == 1:
                # Sample with 1/k in the first round
                log(
                    DEBUG,
                    "FedFS round %s, sample %s clients with 1/k",
                    str(rnd),
                    str(sample_size),
                )
                clients = self._one_over_k_sampling(
                    sample_size=sample_size, client_manager=client_manager)
            else:
                fast_round = is_fast_round(rnd - 1,
                                           r_fast=self.r_fast,
                                           r_slow=self.r_slow)
                log(
                    DEBUG,
                    "FedFS round %s, sample %s clients, fast_round %s",
                    str(rnd),
                    str(sample_size),
                    str(fast_round),
                )
                clients = self._fs_based_sampling(
                    sample_size=sample_size,
                    client_manager=client_manager,
                    fast_round=fast_round,
                )
        else:
            clients = self._one_over_k_sampling(sample_size=sample_size,
                                                client_manager=client_manager)

        # Prepare parameters and config
        parameters = weights_to_parameters(weights)
        config = {}
        if self.on_fit_config_fn is not None:
            # Use custom fit config function if provided
            config = self.on_fit_config_fn(rnd)

        # Set timeout for this round
        if self.dynamic_timeout:
            if self.durations:
                candidates = timeout_candidates(
                    durations=self.durations,
                    max_timeout=self.t_slow,
                )
                timeout = next_timeout(
                    candidates=candidates,
                    percentile=self.dynamic_timeout_percentile,
                )
                config["timeout"] = str(timeout)
            else:
                # The initial round has no past durations; use the max timeout
                config["timeout"] = str(self.t_slow)
        elif self.alternating_timeout:
            use_fast_timeout = is_fast_round(rnd - 1, self.r_fast, self.r_slow)
            config["timeout"] = str(
                self.t_fast if use_fast_timeout else self.t_slow)
        else:
            config["timeout"] = str(self.t_slow)

        # Fit instructions
        fit_ins = FitIns(parameters, config)

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]
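
The dynamic-timeout branch relies on two helpers that are not shown here. A plausible sketch consistent with the call sites, assuming `durations` holds past round durations in seconds (the actual helpers may work on richer per-client records):

import math
from typing import List


def timeout_candidates(durations: List[float],
                       max_timeout: int) -> List[float]:
    """Cap each observed duration at the maximum allowed timeout."""
    return [min(duration, max_timeout) for duration in durations]


def next_timeout(candidates: List[float], percentile: float) -> int:
    """Pick the candidate duration at the given percentile."""
    sorted_candidates = sorted(candidates)
    idx = min(int(len(sorted_candidates) * percentile),
              len(sorted_candidates) - 1)
    return int(math.ceil(sorted_candidates[idx]))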
Example #17
    def __init__(
        self,
        fraction_fit: float = 0.1,
        fraction_eval: float = 0.1,
        min_fit_clients: int = 2,
        min_eval_clients: int = 2,
        min_available_clients: int = 2,
        eval_fn: Optional[
            Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        dummy_model=None,
        quantize_bits: int = 64,
    ) -> None:
        """Federated Averaging strategy.

        Implementation based on https://arxiv.org/abs/1602.05629

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 0.1.
        fraction_eval : float, optional
            Fraction of clients used during validation. Defaults to 0.1.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_eval_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        eval_fn : Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]]
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not to accept rounds containing failures. Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters.
        dummy_model : optional
            Dummy model used to determine the dimensions of the weights
            vector when quantization is used.
        quantize_bits : int, optional
            Number of bits used for quantization. Defaults to 64.
        """
        super().__init__()

        if (
            min_fit_clients > min_available_clients
            or min_eval_clients > min_available_clients
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        self.fraction_fit = fraction_fit
        self.fraction_eval = fraction_eval
        self.min_fit_clients = min_fit_clients
        self.min_eval_clients = min_eval_clients
        self.min_available_clients = min_available_clients
        self.eval_fn = eval_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters
        # A dummy model used to determine the dimensions of the weights
        # vector if quantization is used
        self.dummy_model = dummy_model
        self.q_bits = quantize_bits
Example #18
def start_simulation(  # pylint: disable=too-many-arguments
    *,
    client_fn: Callable[[str], Client],
    num_clients: Optional[int] = None,
    clients_ids: Optional[List[str]] = None,
    client_resources: Optional[Dict[str, int]] = None,
    num_rounds: int = 1,
    strategy: Optional[Strategy] = None,
    ray_init_args: Optional[Dict[str, Any]] = None,
) -> History:
    """Start a Ray-based Flower simulation server.

    Parameters
    ----------
    client_fn : Callable[[str], Client]
        A function creating client instances. The function must take a single
        str argument called `cid`. It should return a single client instance.
        Note that the created client instances are ephemeral and will often be
        destroyed after a single method invocation. Since client instances are
        not long-lived, they should not attempt to carry state over method
        invocations. Any state required by the instance (model, dataset,
        hyperparameters, ...) should be (re-)created in either the call to
        `client_fn` or the call to any of the client methods (e.g., load
        evaluation data in the `evaluate` method itself).
    num_clients : Optional[int]
        The total number of clients in this simulation. This must be set if
        `clients_ids` is not set and vice-versa.
    clients_ids : Optional[List[str]]
        List of `client_id`s, one for each client. This is only required if
        `num_clients` is not set. Setting both `num_clients` and `clients_ids`
        with `len(clients_ids)` not equal to `num_clients` generates an error.
    client_resources : Optional[Dict[str, int]] (default: None)
        CPU and GPU resources for a single client. Supported keys are
        `num_cpus` and `num_gpus`. Example: `{"num_cpus": 4, "num_gpus": 1}`.
        To understand the GPU utilization caused by `num_gpus`, consult the Ray
        documentation on GPU support.
    num_rounds : int (default: 1)
        The number of rounds to train.
    strategy : Optional[flwr.server.Strategy] (default: None)
        An implementation of the abstract base class `flwr.server.Strategy`. If
        no strategy is provided, then `start_server` will use
        `flwr.server.strategy.FedAvg`.
    ray_init_args : Optional[Dict[str, Any]] (default: None)
        Optional dictionary containing arguments for the call to `ray.init`.
        If ray_init_args is None (the default), Ray will be initialized with
        the following default args:

            {
                "ignore_reinit_error": True,
                "include_dashboard": False,
            }

        An empty dictionary can be used (ray_init_args={}) to prevent any
        arguments from being passed to ray.init.

    Returns
    -------
    hist : flwr.server.history.History
        Object containing metrics from training.
    """
    cids: List[str]

    # clients_ids takes precedence
    if clients_ids is not None:
        if (num_clients is not None) and (len(clients_ids) != num_clients):
            log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
            sys.exit()
        else:
            cids = clients_ids
    else:
        if num_clients is None:
            log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
            sys.exit()
        else:
            cids = [str(x) for x in range(num_clients)]

    # Default arguments for Ray initialization
    if ray_init_args is None:
        ray_init_args = {
            "ignore_reinit_error": True,
            "include_dashboard": False,
        }

    # Shut down Ray if it has already been initialized
    if ray.is_initialized():
        ray.shutdown()

    # Initialize Ray
    ray.init(**ray_init_args)
    log(
        INFO,
        "Ray initialized with resources: %s",
        ray.cluster_resources(),
    )

    # Initialize server and server config
    config = {"num_rounds": num_rounds}
    initialized_server, initialized_config = _init_defaults(
        None, config, strategy)
    log(
        INFO,
        "Starting Flower simulation running: %s",
        initialized_config,
    )

    # Register one RayClientProxy object for each client with the ClientManager
    resources = client_resources if client_resources is not None else {}
    for cid in cids:
        client_proxy = RayClientProxy(
            client_fn=client_fn,
            cid=cid,
            resources=resources,
        )
        initialized_server.client_manager().register(client=client_proxy)

    # Start training
    hist = _fl(
        server=initialized_server,
        config=initialized_config,
        force_final_distributed_eval=False,
    )

    return hist
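
A minimal usage sketch based on the docstring above; `MyClient` is a placeholder for your own `flwr.client.Client` implementation:

def client_fn(cid: str) -> Client:
    # (Re-)create everything the client needs; instances are ephemeral
    return MyClient(cid)


hist = start_simulation(
    client_fn=client_fn,
    num_clients=10,
    client_resources={"num_cpus": 1},
    num_rounds=3,
)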
Example #19
def main() -> None:
    """Download data."""
    log(INFO, "Download Keyword Detection")
    tf_hotkey_partitioned.hotkey_load()
Example #20
def start_client(
    server_address: str,
    client: Client,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    root_certificates: Optional[bytes] = None,
) -> None:
    """Start a Flower Client which connects to a gRPC server.

    Parameters
    ----------
        server_address: str. The IPv6 address of the server. If the Flower
            server runs on the same machine on port 8080, then `server_address`
            would be `"[::]:8080"`.
        client: flwr.client.Client. An implementation of the abstract base
            class `flwr.client.Client`.
        grpc_max_message_length: int (default: 536_870_912, this equals 512MB).
            The maximum length of gRPC messages that can be exchanged with the
            Flower server. The default should be sufficient for most models.
            Users who train very large models might need to increase this
            value. Note that the Flower server needs to be started with the
            same value (see `flwr.server.start_server`), otherwise it will not
            know about the increased limit and block larger messages.
        root_certificates: bytes (default: None)
            The PEM-encoded root certificates as a byte string. If provided, a secure
            connection using the certificates will be established to an
            SSL-enabled Flower server.

    Returns
    -------
        None

    Examples
    --------
    Starting a client with an insecure server connection:

    >>> start_client(
    >>>     server_address="localhost:8080",
    >>>     client=FlowerClient(),
    >>> )

    Starting an SSL-enabled client:

    >>> from pathlib import Path
    >>> start_client(
    >>>     server_address="localhost:8080",
    >>>     client=FlowerClient(),
    >>>     root_certificates=Path("/crts/root.pem").read_bytes(),
    >>> )
    """
    while True:
        sleep_duration: int = 0
        with grpc_connection(
                server_address,
                max_message_length=grpc_max_message_length,
                root_certificates=root_certificates,
        ) as conn:
            receive, send = conn
            log(INFO, "Opened (insecure) gRPC connection")

            while True:
                server_message = receive()
                client_message, sleep_duration, keep_going = handle(
                    client, server_message)
                send(client_message)
                if not keep_going:
                    break
        if sleep_duration == 0:
            log(INFO, "Disconnect and shut down")
            break
        # Sleep and reconnect afterwards
        log(
            INFO,
            "Disconnect, then re-establish connection after %s second(s)",
            sleep_duration,
        )
        time.sleep(sleep_duration)
Example #21
def grpc_connection(
    server_address: str,
    max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    root_certificates: Optional[bytes] = None,
) -> Iterator[Tuple[Callable[[], ServerMessage], Callable[[ClientMessage],
                                                          None]]]:
    """Establish an insecure gRPC connection to a gRPC server.

    Parameters
    ----------
    server_address : str
        The IPv6 address of the server. If the Flower server runs on the same machine
        on port 8080, then `server_address` would be `"[::]:8080"`.
    max_message_length : int (default: 536_870_912, this equals 512MB)
        The maximum length of gRPC messages that can be exchanged with the Flower
        server. The default should be sufficient for most models. Users who train
        very large models might need to increase this value. Note that the Flower
        server needs to be started with the same value
        (see `flwr.server.start_server`), otherwise it will not know about the
        increased limit and block larger messages.
    root_certificates : Optional[bytes] (default: None)
        The PEM-encoded root certificates as a byte string. If provided, a secure
        connection using the certificates will be established to an SSL-enabled
        Flower server.

    Returns
    -------
    receive, send : Callable, Callable

    Examples
    --------
    Establishing an SSL-enabled connection to the server:

    >>> from pathlib import Path
    >>> with grpc_connection(
    >>>     server_address,
    >>>     max_message_length=grpc_max_message_length,
    >>>     root_certificates=Path("/crts/root.pem").read_bytes(),
    >>> ) as conn:
    >>>     receive, send = conn
    >>>     server_message = receive()
    >>>     # do something here
    >>>     send(client_message)
    """
    channel_options = [
        ("grpc.max_send_message_length", max_message_length),
        ("grpc.max_receive_message_length", max_message_length),
    ]

    if root_certificates is not None:
        ssl_channel_credentials = grpc.ssl_channel_credentials(
            root_certificates)
        channel = grpc.secure_channel(server_address,
                                      ssl_channel_credentials,
                                      options=channel_options)
    else:
        channel = grpc.insecure_channel(server_address,
                                        options=channel_options)

    channel.subscribe(on_channel_state_change)

    queue: Queue[ClientMessage] = Queue(  # pylint: disable=unsubscriptable-object
        maxsize=1)
    stub = FlowerServiceStub(channel)

    server_message_iterator: Iterator[ServerMessage] = stub.Join(
        iter(queue.get, None))

    receive: Callable[[],
                      ServerMessage] = lambda: next(server_message_iterator)
    send: Callable[[ClientMessage],
                   None] = lambda msg: queue.put(msg, block=False)

    try:
        yield (receive, send)
    finally:
        # Make sure the channel gets closed in any case
        channel.close()
        log(DEBUG, "gRPC channel closed")
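
Design note: `stub.Join` consumes `iter(queue.get, None)`, a blocking iterator that only terminates once `None` is enqueued, while `send` uses `queue.put(msg, block=False)` on a queue of size one, so trying to send a second message before the first has been consumed raises `queue.Full` instead of silently buffering. This keeps the client in strict one-request/one-response lockstep with the server.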
Example #22
def main() -> None:
    """Download data."""
    log(INFO, "Download Fashion-MNIST")
    tf.keras.datasets.fashion_mnist.load_data()
Example #23
def _fl(server: Server, config: Dict[str, int]) -> None:
    # Fit model
    hist = server.fit(num_rounds=config["num_rounds"])
    log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
    log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
    log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
    log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))

    # Temporary workaround to force distributed evaluation
    server.strategy.eval_fn = None  # type: ignore

    # Evaluate the final trained model
    res = server.evaluate_round(rnd=-1)
    if res is not None:
        loss, _, (results, failures) = res
        log(INFO, "app_evaluate: federated loss: %s", str(loss))
        log(
            INFO,
            "app_evaluate: results %s",
            str([(res[0].cid, res[1]) for res in results]),
        )
        log(INFO, "app_evaluate: failures %s", str(failures))
    else:
        log(INFO, "app_evaluate: no evaluation result")

    # Graceful shutdown
    server.disconnect_all_clients()
Example #24
def main() -> None:
    """Start server and train a number of rounds."""
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_hotkey_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1)
    if server_setting.dry_run:
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = keyword_cnn(input_shape=(80, 40, 1), seed=SEED)

    # Strategy
    eval_fn = get_eval_fn(model=model,
                          num_classes=10,
                          xy_test=(x_test, y_test))
    on_fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    if server_setting.strategy == "fedavg":
        strategy = fl.server.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )

    if server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy")
        strategy = fl.server.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.9,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            t_fast=math.ceil(0.5 * server_setting.training_round_timeout),
            t_slow=server_setting.training_round_timeout,
        )

    if server_setting.strategy == "fedfs-v0":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fedfs-v0` strategy")
        t_fast = (math.ceil(0.5 * server_setting.training_round_timeout)
                  if server_setting.training_round_timeout_short is None else
                  server_setting.training_round_timeout_short)
        strategy = fl.server.strategy.FedFSv0(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )

    if server_setting.strategy == "qffedavg":
        strategy = fl.server.strategy.QFedAvg(
            q_param=0.2,
            qffl_learning_rate=0.1,
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )

    # Run server
    fl.server.start_server(
        DEFAULT_SERVER_ADDRESS,
        config={"num_rounds": server_setting.rounds},
        strategy=strategy,
    )
Example #25
    def fit(self, num_rounds: int) -> History:
        """Run federated averaging for a number of rounds."""
        history = History()

        # Initialize parameters
        log(INFO, "Initializing global parameters")
        self.parameters = self._get_initial_parameters()
        log(INFO, "Evaluating initial parameters")
        res = self.strategy.evaluate(parameters=self.parameters)
        if res is not None:
            log(
                INFO,
                "initial parameters (loss, other metrics): %s, %s",
                res[0],
                res[1],
            )
            history.add_loss_centralized(rnd=0, loss=res[0])
            history.add_metrics_centralized(rnd=0, metrics=res[1])

        # Run federated learning for num_rounds
        log(INFO, "FL starting")
        start_time = timeit.default_timer()

        for current_round in range(1, num_rounds + 1):
            # Train model and replace previous global model
            res_fit = self.fit_round(rnd=current_round)
            if res_fit:
                parameters_prime, _, _ = res_fit  # fit_metrics_aggregated
                if parameters_prime:
                    self.parameters = parameters_prime

            # Evaluate model using strategy implementation
            res_cen = self.strategy.evaluate(parameters=self.parameters)
            if res_cen is not None:
                loss_cen, metrics_cen = res_cen
                log(
                    INFO,
                    "fit progress: (%s, %s, %s, %s)",
                    current_round,
                    loss_cen,
                    metrics_cen,
                    timeit.default_timer() - start_time,
                )
                history.add_loss_centralized(rnd=current_round, loss=loss_cen)
                history.add_metrics_centralized(rnd=current_round,
                                                metrics=metrics_cen)

            # Evaluate model on a sample of available clients
            res_fed = self.evaluate_round(rnd=current_round)
            if res_fed:
                loss_fed, evaluate_metrics_fed, _ = res_fed
                if loss_fed:
                    history.add_loss_distributed(rnd=current_round,
                                                 loss=loss_fed)
                    history.add_metrics_distributed(
                        rnd=current_round, metrics=evaluate_metrics_fed)

        # Bookkeeping
        end_time = timeit.default_timer()
        elapsed = end_time - start_time
        log(INFO, "FL finished in %s", elapsed)
        return history
Example #26
def start_server(  # pylint: disable=too-many-arguments
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    force_final_distributed_eval: bool = False,
) -> None:
    """Start a Flower server using the gRPC transport layer.

    Arguments:
        server_address: Optional[str] (default: `"[::]:8080"`). The IPv6
            address of the server.
        server: Optional[flwr.server.Server] (default: None). An implementation
            of the abstract base class `flwr.server.Server`. If no instance is
            provided, then `start_server` will create one.
        config: Optional[Dict[str, int]] (default: None). The only currently
            supported value is `num_rounds`, so a full configuration object
            instructing the server to perform three rounds of federated
            learning looks like the following: `{"num_rounds": 3}`.
        strategy: Optional[flwr.server.Strategy] (default: None). An
            implementation of the abstract base class `flwr.server.Strategy`.
            If no strategy is provided, then `start_server` will use
            `flwr.server.strategy.FedAvg`.
        grpc_max_message_length: int (default: 536_870_912, this equals 512MB).
            The maximum length of gRPC messages that can be exchanged with the
            Flower clients. The default should be sufficient for most models.
            Users who train very large models might need to increase this
            value. Note that the Flower clients need to be started with the
            same value (see `flwr.client.start_client`), otherwise clients will
            not know about the increased limit and block larger messages.
        force_final_distributed_eval: bool (default: False).
            Forces a distributed evaluation to occur after the last training
            epoch when enabled.

    Returns:
        None.
    """
    initialized_server, initialized_config = _init_defaults(
        server, config, strategy)

    # Start gRPC server
    grpc_server = start_insecure_grpc_server(
        client_manager=initialized_server.client_manager(),
        server_address=server_address,
        max_message_length=grpc_max_message_length,
    )
    log(
        INFO,
        "Flower server running (insecure, %s rounds)",
        initialized_config["num_rounds"],
    )

    _fl(
        server=initialized_server,
        config=initialized_config,
        force_final_distributed_eval=force_final_distributed_eval,
    )

    # Stop the gRPC server
    grpc_server.stop(grace=1)
Example #27
def get_on_fit_config_fn(
        lr_initial: float, timeout: Optional[int],
        partial_updates: bool) -> Callable[[int], Dict[str, fl.common.Scalar]]:
    """Return a function which returns training configurations."""
    def fit_config(rnd: int) -> Dict[str, fl.common.Scalar]:
        """Return a configuration with static batch size and (local) epochs."""
        config: Dict[str, fl.common.Scalar] = {
            "epoch_global": str(rnd),
            "epochs": str(5),
            "batch_size": str(32),
            "lr_initial": str(lr_initial),
            "lr_decay": str(0.99),
            "partial_updates": "1" if partial_updates else "0",
        }
        if timeout is not None:
            config["timeout"] = str(timeout)

        return config

    return fit_config


if __name__ == "__main__":
    # pylint: disable=broad-except
    try:
        main()
    except Exception as err:
        log(ERROR, "Fatal error in main")
        log(ERROR, err, exc_info=True, stack_info=True)

        # Raise the error again so the exit code is correct
        raise err
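
Since every value in the config above is stringified before transmission, the receiving client has to parse the values back, along these lines (a sketch):

epochs = int(config["epochs"])
batch_size = int(config["batch_size"])
lr_initial = float(config["lr_initial"])
lr_decay = float(config["lr_decay"])
partial_updates = config["partial_updates"] == "1"
timeout = int(config["timeout"]) if "timeout" in config else None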
Example #28
def start_server(
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
) -> None:
    """Start a Flower server using the gRPC transport layer.

    Arguments:
        server_address: Optional[str] (default: `"[::]:8080"`). The IPv6
            address of the server.
        server: Optional[flwr.server.Server] (default: None). An implementation
            of the abstract base class `flwr.server.Server`. If no instance is
            provided, then `start_server` will create one.
        config: Optional[Dict[str, int]] (default: None). The only currently
            supported value is `num_rounds`, so a full configuration object
            instructing the server to perform three rounds of federated
            learning looks like the following: `{"num_rounds": 3}`.
        strategy: Optional[flwr.server.Strategy] (default: None). An
            implementation of the abstract base class `flwr.server.Strategy`.
            If no strategy is provided, then `start_server` will use
            `flwr.server.strategy.FedAvg`.
        grpc_max_message_length: int (default: 536_870_912, this equals 512MB).
            The maximum length of gRPC messages that can be exchanged with the
            Flower clients. The default should be sufficient for most models.
            Users who train very large models might need to increase this
            value. Note that the Flower clients need to be started with the
            same value (see `flwr.client.start_client`), otherwise clients will
            not know about the increased limit and block larger messages.

    Returns:
        None.
    """

    # Create server instance if none was given
    if server is None:
        client_manager = SimpleClientManager()
        if strategy is None:
            strategy = FedAvg()
        server = Server(client_manager=client_manager, strategy=strategy)

    # Set default config values
    if config is None:
        config = {}
    if "num_rounds" not in config:
        config["num_rounds"] = 1

    # Start gRPC server
    grpc_server = start_insecure_grpc_server(
        client_manager=server.client_manager(),
        server_address=server_address,
        max_message_length=grpc_max_message_length,
    )
    log(INFO, "Flower server running (insecure, %s rounds)",
        config["num_rounds"])

    # Fit model
    hist = server.fit(num_rounds=config["num_rounds"])
    log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
    log(INFO, "app_fit: accuracies_distributed %s",
        str(hist.accuracies_distributed))
    log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
    log(INFO, "app_fit: accuracies_centralized %s",
        str(hist.accuracies_centralized))

    # Temporary workaround to force distributed evaluation
    server.strategy.eval_fn = None  # type: ignore

    # Evaluate the final trained model
    res = server.evaluate(rnd=-1)
    if res is not None:
        loss, (results, failures) = res
        log(INFO, "app_evaluate: federated loss: %s", str(loss))
        log(
            INFO,
            "app_evaluate: results %s",
            str([(res[0].cid, res[1]) for res in results]),
        )
        log(INFO, "app_evaluate: failures %s", str(failures))
    else:
        log(INFO, "app_evaluate: no evaluation result")

    # Graceful shutdown
    server.disconnect_all_clients()

    # Stop the gRPC server
    grpc_server.stop(1)
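
Since only `num_rounds` is read from `config`, a minimal invocation with an explicit strategy could look like the sketch below (the sampling fraction and client minimum are illustrative):

# A minimal sketch: three rounds of FedAvg with custom sampling settings
strategy = FedAvg(fraction_fit=0.5, min_available_clients=2)
start_server(
    server_address="[::]:8080",
    config={"num_rounds": 3},
    strategy=strategy,
)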
Example #29
    def create_instances(
        self,
        num_cpu: int,
        num_ram: float,
        timeout: int,
        num_instance: int = 1,
        gpu: bool = False,
    ) -> List[AdapterInstance]:
        """Create one or more EC2 instance(s) of the same type.

        Args:
            num_cpu (int): Number of instance vCPUs (values in
                            ec2_adapter.INSTANCE_TYPES_CPU or INSTANCE_TYPES_GPU)
            num_ram (float): RAM in GB (values in ec2_adapter.INSTANCE_TYPES_CPU
                            or INSTANCE_TYPES_GPU)
            timeout (int): Timeout in minutes
            num_instance (int): Number of instances to launch, provided EC2
                            currently has capacity available
            gpu (bool): Whether to select a GPU instance type
        """
        # The instances will be set to terminate after shutdown.
        # This is a fail-safe in case something goes wrong and the instances
        # are not shut down correctly.
        user_data = ["#!/bin/bash", f"sudo shutdown -P {timeout}"]
        user_data_str = "\n".join(user_data)

        instance_type, hourly_price = find_instance_type(
            num_cpu, num_ram,
            INSTANCE_TYPES_GPU if gpu else INSTANCE_TYPES_CPU)

        hourly_price_total = round(num_instance * hourly_price, 2)

        log(
            INFO,
            "Starting %s instances of type %s which in total will roughly cost $%s an hour.",
            num_instance,
            instance_type,
            hourly_price_total,
        )

        result: EC2RunInstancesResult = self.ec2.run_instances(
            BlockDeviceMappings=[{
                "DeviceName": "/dev/sda1",
                "Ebs": {
                    "DeleteOnTermination": True
                }
            }],
            ImageId=self.image_id,
            # We always want an exact number of instances
            MinCount=num_instance,
            MaxCount=num_instance,
            InstanceType=instance_type,
            KeyName=self.key_name,
            IamInstanceProfile={"Name": "FlowerInstanceProfile"},
            SubnetId=self.subnet_id,
            SecurityGroupIds=self.security_group_ids,
            TagSpecifications=self.tag_specifications,
            InstanceInitiatedShutdownBehavior="terminate",
            UserData=user_data_str,
        )

        instance_ids = [ins["InstanceId"] for ins in result["Instances"]]

        # As soon as all instances report a "running" state, we still have to
        # check the InstanceStatus, which reports impaired functionality
        # stemming from issues internal to the instance, such as impaired
        # reachability
        try:
            self._wait_until_instances_are_reachable(instance_ids=instance_ids)
        except EC2StatusTimeout:
            self.terminate_instances(instance_ids)
            raise EC2CreateInstanceFailure()

        return self.list_instances(instance_ids=instance_ids)
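
A hedged usage sketch, assuming `adapter` is an initialized instance of the surrounding EC2 adapter class (the resource numbers are illustrative):

# Hypothetical call: two 4-vCPU / 16 GB instances that self-terminate
# after 60 minutes if they are not shut down cleanly
instances = adapter.create_instances(
    num_cpu=4,
    num_ram=16.0,
    timeout=60,
    num_instance=2,
    gpu=False,
)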
Example #30
File: server.py Project: vballoli/flower
def main() -> None:
    """Start server and train a number of rounds."""
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_cifar_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1, cifar100=NUM_CLASSES == 100
    )
    if server_setting.dry_run:
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = resnet50v2(input_shape=(32, 32, 3), num_classes=NUM_CLASSES, seed=SEED)

    # Create client_manager
    client_manager = fl.server.SimpleClientManager()

    # Strategy
    eval_fn = get_eval_fn(
        model=model, num_classes=NUM_CLASSES, xy_test=(x_test, y_test)
    )
    fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    if server_setting.strategy == "fedavg":
        strategy = fl.server.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=fit_config_fn,
        )

    if server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        strategy = fl.server.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.8,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            t_fast=math.ceil(0.5 * server_setting.training_round_timeout),
            t_slow=server_setting.training_round_timeout,
        )
    else:
        # Fail fast on an unknown strategy name instead of hitting a
        # NameError when `strategy` is used below
        raise ValueError(f"Unknown strategy: {server_setting.strategy}")

    # Run server
    server = fl.server.Server(client_manager=client_manager, strategy=strategy)
    fl.server.start_server(
        DEFAULT_SERVER_ADDRESS, server, config={"num_rounds": server_setting.rounds},
    )
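
For reference, the `t_fast`/`t_slow` values above are derived from the configured round timeout. A quick sketch of the arithmetic with an illustrative 60-second timeout:

import math

# Illustrative: with training_round_timeout = 60 seconds, fast rounds get
# half the budget (rounded up) while slow rounds keep the full budget
training_round_timeout = 60
t_fast = math.ceil(0.5 * training_round_timeout)  # 30
t_slow = training_round_timeout                   # 60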