Example No. 1
@contextmanager  # required for the `with insecure_grpc_connection(...)` usage in Example No. 2
def insecure_grpc_connection(
    server_address: str,
) -> Iterator[Tuple[Callable[[], ServerMessage], Callable[[ClientMessage],
                                                          None]]]:
    """Establish an insecure gRPC connection to a gRPC server."""
    channel = grpc.insecure_channel(
        server_address,
        options=[
            ("grpc.max_send_message_length", 256 * 1024 * 1024),
            ("grpc.max_receive_message_length", 256 * 1024 * 1024),
        ],
    )
    channel.subscribe(on_channel_state_change)

    queue: Queue[ClientMessage] = Queue(  # pylint: disable=unsubscriptable-object
        maxsize=1)
    stub = FlowerServiceStub(channel)  # type: ignore

    server_message_iterator: Iterator[ServerMessage] = stub.Join(
        iter(queue.get, None))

    receive: Callable[[],
                      ServerMessage] = lambda: next(server_message_iterator)
    send: Callable[[ClientMessage],
                   None] = lambda msg: queue.put(msg, block=False)

    try:
        yield (receive, send)
    finally:
        # Make sure the channel is closed, even if an exception was raised
        channel.close()
        log(DEBUG, "Insecure gRPC channel closed")
Example No. 2
def start_client(server_address: str, client: Client) -> None:
    """Start a Flower Client which connects to a gRPC server."""
    with insecure_grpc_connection(server_address) as conn:
        receive, send = conn
        log(INFO, "Opened (insecure) gRPC connection")

        while True:
            server_message = receive()
            client_message = handle(client, server_message)
            send(client_message)
Example No. 3
    def exec(self, instance_name: str, command: str) -> ExecInfo:
        """Run command on instance and return stdout."""
        log(DEBUG, "Exec on %s: %s", instance_name, command)

        instance = self.get_instance(instance_name)

        with ssh_connection(instance, self.ssh_credentials) as client:
            _, stdout, stderr = client.exec_command(command)
            stdout = stdout.readlines()
            stderr = stderr.readlines()

        return stdout, stderr
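The `ssh_connection` helper used above is not part of this excerpt. A minimal sketch of such a context manager, assuming Paramiko, an `instance` object with a `public_ip` attribute, and credentials given as a `(username, key_path)` pair (all assumptions, not the actual helper):

from contextlib import contextmanager
from typing import Iterator, Tuple

import paramiko


@contextmanager
def ssh_connection(instance, ssh_credentials: Tuple[str, str]) -> Iterator[paramiko.SSHClient]:
    """Yield a connected SSH client and close it afterwards (sketch)."""
    username, key_path = ssh_credentials  # assumed credential layout
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect(hostname=instance.public_ip, username=username, key_filename=key_path)
    try:
        yield client
    finally:
        client.close()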
Example No. 4
    def _fs_based_sampling(
        self, sample_size: int, client_manager: ClientManager, fast_round: bool
    ) -> List[ClientProxy]:
        """Sample clients with 1/k * c/m in fast rounds and 1 - c/m in slow rounds."""
        all_clients: Dict[str, ClientProxy] = client_manager.all()
        k = len(all_clients)
        cid_idx: Dict[int, str] = {}
        raw: List[float] = []
        for idx, (cid, _) in enumerate(all_clients.items()):
            cid_idx[idx] = cid

            if cid in self.contributions.keys():
                # Previously selected clients
                contribs: List[Tuple[int, int, int]] = self.contributions[cid]

                # pylint: disable-msg=invalid-name
                if self.use_past_contributions:
                    cs = [c for _, c, _ in contribs]
                    ms = [m for _, _, m in contribs]
                    c_over_m = sum(cs) / sum(ms)
                else:
                    _, c, m = contribs[-1]
                    c_over_m = c / m
                # pylint: enable-msg=invalid-name

                if fast_round:
                    importance = (1 / k) * c_over_m + E
                else:
                    importance = 1 - c_over_m + E
            else:
                # Previously unselected clients
                if fast_round:
                    importance = 1 / k
                else:
                    importance = 1
            raw.append(importance)

        log(
            DEBUG,
            "FedFS _fs_based_sampling, sample %s clients, raw %s",
            str(sample_size),
            str(raw),
        )

        return normalize_and_sample(
            all_clients=all_clients,
            cid_idx=cid_idx,
            raw=np.array(raw),
            sample_size=sample_size,
            use_softmax=False,
        )
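The `normalize_and_sample` helper called here is not shown. A minimal sketch, assuming it turns the raw importance scores into a probability distribution and draws distinct clients (the actual Flower helper may differ):

from typing import Dict, List

import numpy as np


def normalize_and_sample(
    all_clients: Dict[str, "ClientProxy"],
    cid_idx: Dict[int, str],
    raw: np.ndarray,
    sample_size: int,
    use_softmax: bool = False,
) -> List["ClientProxy"]:
    """Normalize raw importances and sample clients without replacement (sketch)."""
    indices = np.arange(len(raw))
    if use_softmax:
        probs = np.exp(raw) / np.sum(np.exp(raw))
    else:
        probs = raw / np.sum(raw)
    sampled = np.random.choice(indices, size=sample_size, replace=False, p=probs)
    return [all_clients[cid_idx[int(idx)]] for idx in sampled]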
Example No. 5
def main() -> None:
    """Download data."""
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--cifar",
        type=int,
        choices=[10, 100],
        default=10,
        help="CIFAR version, allowed values: 10 or 100 (default: 10)",
    )
    args = parser.parse_args()
    log(INFO, "Download CIFAR-%s", args.cifar)

    # Load model and data
    download_data(num_classes=args.cifar)
Example No. 6
    def fit(self, ins: fl.FitIns) -> fl.FitRes:
        weights: fl.Weights = fl.parameters_to_weights(ins[0])
        config = ins[1]
        log(
            DEBUG,
            "fit on %s (examples: %s), config %s",
            self.cid,
            self.num_examples_train,
            config,
        )

        # Training configuration
        # epoch_global = int(config["epoch_global"])
        epochs = int(config["epochs"])
        batch_size = int(config["batch_size"])
        # lr_initial = float(config["lr_initial"])
        # lr_decay = float(config["lr_decay"])
        timeout = int(config["timeout"]) if "timeout" in config else None
        partial_updates = bool(int(config["partial_updates"]))

        # Use provided weights to update the local model
        self.model.set_weights(weights)

        # Train the local model using the local dataset
        completed, fit_duration, num_examples = custom_fit(
            model=self.model,
            dataset=self.ds_train,
            num_epochs=epochs,
            batch_size=batch_size,
            callbacks=[],
            delay_factor=self.delay_factor,
            timeout=timeout,
        )
        log(DEBUG, "client %s had fit_duration %s", self.cid, fit_duration)

        # Compute the maximum number of examples which could have been processed
        num_examples_ceil = self.num_examples_train * epochs

        if not completed and not partial_updates:
            # Return empty update if local update could not be completed in time
            parameters = fl.weights_to_parameters([])
        else:
            # Return the refined weights and the number of examples used for training
            parameters = fl.weights_to_parameters(self.model.get_weights())
        return parameters, num_examples, num_examples_ceil, fit_duration
Example No. 7
def main() -> None:
    """Load data, create and start CIFAR-10/100 client."""
    args = parse_args()

    client_setting = get_client_setting(args.setting, args.cid)

    # Configure logger
    configure(identifier=f"client:{client_setting.cid}", host=args.log_host)
    log(INFO, "Starting client, settings: %s", client_setting)

    # Load model
    model = resnet50v2(input_shape=(32, 32, 3),
                       num_classes=NUM_CLASSES,
                       seed=SEED)

    # Load local data partition
    (xy_train_partitions,
     xy_test_partitions), _ = tf_cifar_partitioned.load_data(
         iid_fraction=client_setting.iid_fraction,
         num_partitions=client_setting.num_clients,
         cifar100=False,
     )
    x_train, y_train = xy_train_partitions[client_setting.partition]
    x_test, y_test = xy_test_partitions[client_setting.partition]
    if client_setting.dry_run:
        x_train = x_train[0:100]
        y_train = y_train[0:100]
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Start client
    client = VisionClassificationClient(
        client_setting.cid,
        model,
        (x_train, y_train),
        (x_test, y_test),
        client_setting.delay_factor,
        NUM_CLASSES,
        augment=True,
        augment_horizontal_flip=True,
        augment_offset=2,
    )
    fl.app.client.start_client(args.server_address, client)
Example No. 8
def main() -> None:
    """Load data, create and start Fashion-MNIST client."""
    args = parse_args()

    client_setting = get_client_setting(args.setting, args.cid)

    # Configure logger
    configure(identifier=f"client:{client_setting.cid}", host=args.log_host)
    log(INFO, "Starting client, settings: %s", client_setting)

    # Load model
    model = orig_cnn(input_shape=(28, 28, 1), seed=SEED)

    # Load local data partition
    (
        (xy_train_partitions, xy_test_partitions),
        _,
    ) = tf_fashion_mnist_partitioned.load_data(
        iid_fraction=client_setting.iid_fraction,
        num_partitions=client_setting.num_clients,
    )
    x_train, y_train = xy_train_partitions[client_setting.partition]
    x_test, y_test = xy_test_partitions[client_setting.partition]
    if client_setting.dry_run:
        x_train = x_train[0:100]
        y_train = y_train[0:100]
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Start client
    client = VisionClassificationClient(
        client_setting.cid,
        model,
        (x_train, y_train),
        (x_test, y_test),
        client_setting.delay_factor,
        10,
        augment=True,
        augment_horizontal_flip=False,
        augment_offset=1,
    )
    fl.app.client.start_client(args.server_address, client)
Example No. 9
    def on_configure_fit(
        self, rnd: int, weights: Weights, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""

        # Block until `min_num_clients` are available
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        success = client_manager.wait_for(
            num_clients=min_num_clients, timeout=WAIT_TIMEOUT
        )
        if not success:
            # Do not continue if not enough clients are available
            log(
                INFO,
                "FedFS: not enough clients available after timeout %s",
                WAIT_TIMEOUT,
            )
            return []

        # Sample clients
        clients = self._contribution_based_sampling(
            sample_size=sample_size, client_manager=client_manager
        )

        # Prepare parameters and config
        parameters = weights_to_parameters(weights)
        config = {}
        if self.on_fit_config_fn is not None:
            # Use custom fit config function if provided
            config = self.on_fit_config_fn(rnd)

        # Set timeout for this round
        use_fast_timeout = is_fast_round(rnd - 1, self.r_fast, self.r_slow)
        config["timeout"] = str(self.t_fast if use_fast_timeout else self.t_slow)

        # Fit instructions
        fit_ins = (parameters, config)

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]
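The `is_fast_round` helper used here (and in Example No. 20 below) is not shown. One plausible sketch, assuming rounds alternate between `r_fast` fast rounds followed by `r_slow` slow rounds:

def is_fast_round(rnd: int, r_fast: int, r_slow: int) -> bool:
    """Return True if the zero-based round index falls into a fast phase (sketch)."""
    remainder = rnd % (r_fast + r_slow)
    return remainder < r_fast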
Example No. 10
    def evaluate(self, ins: fl.EvaluateIns) -> fl.EvaluateRes:
        weights = fl.parameters_to_weights(ins[0])
        config = ins[1]
        log(
            DEBUG,
            "evaluate on %s (examples: %s), config %s",
            self.cid,
            self.num_examples_test,
            config,
        )

        # Use provided weights to update the local model
        self.model.set_weights(weights)

        # Evaluate the updated model on the local dataset
        loss, acc = keras_evaluate(self.model,
                                   self.ds_test,
                                   batch_size=self.num_examples_test)

        # Return the number of evaluation examples and the evaluation result (loss)
        return self.num_examples_test, loss, acc
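A sketch of the `keras_evaluate` helper, assuming the model is compiled with a loss and exactly one accuracy metric (names and behaviour are assumptions, not the original helper):

from typing import Tuple

import tensorflow as tf


def keras_evaluate(
    model: tf.keras.Model, dataset: tf.data.Dataset, batch_size: int
) -> Tuple[float, float]:
    """Evaluate the model on the dataset and return (loss, accuracy) (sketch)."""
    ds_test = dataset.batch(batch_size=batch_size, drop_remainder=False)
    loss, accuracy = model.evaluate(ds_test, verbose=0)
    return float(loss), float(accuracy)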
Example No. 11
    def fit_round(self, rnd: int) -> Optional[Weights]:
        """Perform a single round of federated averaging."""
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.on_configure_fit(
            rnd=rnd, weights=self.weights, client_manager=self._client_manager)
        log(
            DEBUG,
            "fit_round: strategy sampled %s clients",
            len(client_instructions),
        )
        if not client_instructions:
            log(INFO, "fit_round: no clients sampled, cancel fit")
            return None

        # Collect training results from all clients participating in this round
        results, failures = fit_clients(client_instructions)
        log(
            DEBUG,
            "fit_round received %s results and %s failures",
            len(results),
            len(failures),
        )

        # Aggregate training results
        return self.strategy.on_aggregate_fit(rnd, results, failures)
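The `fit_clients` helper is not included in this excerpt. A sketch that issues the fit instructions concurrently, mirroring the ThreadPoolExecutor pattern used in Examples No. 14 and 15 (the concrete helper may differ):

import concurrent.futures


def fit_clients(client_instructions):
    """Call fit on each sampled client in parallel and collect results and failures (sketch)."""
    results, failures = [], []
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        future_to_client = {
            executor.submit(client.fit, ins): client
            for client, ins in client_instructions
        }
        for future in concurrent.futures.as_completed(future_to_client):
            client = future_to_client[future]
            try:
                results.append((client, future.result()))
            # pylint: disable=broad-except
            except Exception as exc:
                failures.append(exc)
    return results, failures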
Example No. 12
    def evaluate(
        self, rnd: int
    ) -> Optional[Tuple[Optional[float], EvaluateResultsAndFailures]]:
        """Validate current global model on a number of clients."""
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.on_configure_evaluate(
            rnd=rnd, weights=self.weights, client_manager=self._client_manager)
        if not client_instructions:
            log(INFO,
                "evaluate: no clients sampled, cancel federated evaluation")
            return None
        log(
            DEBUG,
            "evaluate: strategy sampled %s clients",
            len(client_instructions),
        )

        # Evaluate current global weights on those clients
        results_and_failures = evaluate_clients(client_instructions)
        results, failures = results_and_failures
        log(
            DEBUG,
            "evaluate received %s results and %s failures",
            len(results),
            len(failures),
        )
        # Aggregate the evaluation results
        loss_aggregated = self.strategy.on_aggregate_evaluate(
            rnd, results, failures)
        return loss_aggregated, results_and_failures
Example No. 13
    def start(self) -> None:
        """Start the instance."""
        instance_groups = group_instances_by_specs(self.instances)

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [
                executor.submit(
                    create_instances, self.adapter, instance_group, self.timeout
                )
                for instance_group in instance_groups
            ]
            concurrent.futures.wait(futures)

            try:
                for future in futures:
                    future.result()
            # pylint: disable=broad-except
            except Exception as exc:
                log(
                    ERROR, "Failed to start the cluster completely. Shutting down...",
                )
                log(ERROR, exc)

                for future in futures:
                    future.cancel()

                self.terminate()
                raise StartFailed()

        for ins in self.instances:
            log(DEBUG, ins)
Example No. 14
    def exec_all(
        self, command: str, groups: Optional[List[str]] = None
    ) -> Dict[str, ExecInfo]:
        """Run command on all instances. If provided filter by group."""
        instance_names = self.get_instance_names(groups)

        results: Dict[str, ExecInfo] = {}

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            # Submit the exec calls and map each future to its instance name
            future_to_result = {
                executor.submit(self.exec, instance_name, command): instance_name
                for instance_name in instance_names
            }

            for future in concurrent.futures.as_completed(future_to_result):
                instance_name = future_to_result[future]
                try:
                    results[instance_name] = future.result()
                # pylint: disable=broad-except
                except Exception as exc:
                    log(ERROR, (instance_name, exc))

        return results
Example No. 15
    def upload_all(
        self, local_path: str, remote_path: str
    ) -> Dict[str, SFTPAttributes]:
        """Upload file to all instances."""
        results: Dict[str, SFTPAttributes] = {}

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            # Submit the upload calls and map each future to its instance name
            future_to_result = {
                executor.submit(
                    self.upload, instance_name, local_path, remote_path
                ): instance_name
                for instance_name in self.get_instance_names()
            }

            for future in concurrent.futures.as_completed(future_to_result):
                instance_name = future_to_result[future]
                try:
                    results[instance_name] = future.result()
                # pylint: disable=broad-except
                except Exception as exc:
                    log(ERROR, (instance_name, exc))

        return results
Example No. 16
    """Return a function which returns training configurations."""

    def fit_config(rnd: int) -> Dict[str, str]:
        """Return a configuration with static batch size and (local) epochs."""
        config = {
            "epoch_global": str(rnd),
            "epochs": str(5),
            "batch_size": str(10),
            "lr_initial": str(lr_initial),
            "lr_decay": str(0.99),
            "partial_updates": "1" if partial_updates else "0",
        }
        if timeout is not None:
            config["timeout"] = str(timeout)

        return config

    return fit_config


if __name__ == "__main__":
    # pylint: disable=broad-except
    try:
        main()
    except Exception as err:
        log(ERROR, "Fatal error in main")
        log(ERROR, err, exc_info=True, stack_info=True)

        # Raise the error again so the exit code is correct
        raise err
Example No. 17
def main() -> None:
    """Download data."""
    log(INFO, "Download Keyword Detection")
    tf_hotkey_partitioned.hotkey_load()
Example No. 18
def custom_fit(
    model: tf.keras.Model,
    dataset: tf.data.Dataset,
    num_epochs: int,
    batch_size: int,
    callbacks: List[tf.keras.callbacks.Callback],
    delay_factor: float = 0.0,
    timeout: Optional[int] = None,
) -> Tuple[bool, float, int]:
    """Train the model using a custom training loop."""
    ds_train = dataset.batch(batch_size=batch_size, drop_remainder=False)

    # Keep results for plotting
    train_loss_results = []
    train_accuracy_results = []

    # Optimizer
    optimizer = tf.keras.optimizers.Adam()

    fit_begin = timeit.default_timer()
    num_examples = 0
    for epoch in range(num_epochs):
        log(INFO, "Starting epoch %s", epoch)

        epoch_loss_avg = tf.keras.metrics.Mean()
        epoch_accuracy = tf.keras.metrics.CategoricalAccuracy()

        # Single loop over the dataset
        batch_begin = timeit.default_timer()
        num_examples_batch = 0
        for batch, (x, y) in enumerate(ds_train):
            num_examples_batch += len(x)

            # Optimize the model
            loss_value, grads = grad(model, x, y)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            # Track progress
            epoch_loss_avg.update_state(
                loss_value)  # Add the current batch loss
            epoch_accuracy.update_state(y, model(x, training=True))

            # Track the number of examples used for training
            num_examples += x.shape[0]

            # Delay
            batch_duration = timeit.default_timer() - batch_begin
            if delay_factor > 0.0:
                time.sleep(batch_duration * delay_factor)

            # Progress log
            if batch % 100 == 0:
                log(
                    INFO,
                    "Batch %s: loss %s (%s examples processed, batch duration: %s)",
                    batch,
                    loss_value,
                    num_examples_batch,
                    batch_duration,
                )

            # Timeout
            if timeout is not None:
                fit_duration = timeit.default_timer() - fit_begin
                if fit_duration > timeout:
                    log(INFO, "client timeout")
                    return (False, fit_duration, num_examples)
            batch_begin = timeit.default_timer()

        # End epoch: record per-epoch metrics
        train_loss_results.append(epoch_loss_avg.result())
        train_accuracy_results.append(epoch_accuracy.result())
        log(
            INFO,
            "Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(
                epoch, epoch_loss_avg.result(), epoch_accuracy.result()),
        )

    fit_duration = timeit.default_timer() - fit_begin
    return True, fit_duration, num_examples
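The `grad` helper used in the training loop above is not shown. A sketch assuming a categorical cross-entropy loss on the model's probability outputs (matching the `CategoricalAccuracy` metric used above):

from typing import List, Tuple

import tensorflow as tf

loss_object = tf.keras.losses.CategoricalCrossentropy()


def grad(
    model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor
) -> Tuple[tf.Tensor, List[tf.Tensor]]:
    """Compute the loss and its gradients w.r.t. the trainable variables (sketch)."""
    with tf.GradientTape() as tape:
        loss_value = loss_object(y_true=y, y_pred=model(x, training=True))
    return loss_value, tape.gradient(loss_value, model.trainable_variables)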
Example No. 19
    def fit(self, num_rounds: int) -> History:
        """Run federated averaging for a number of rounds."""
        history = History()
        # Initialize weights by asking one client to return theirs
        self.weights = self._get_initial_weights()
        res = self.strategy.evaluate(weights=self.weights)
        if res is not None:
            log(
                INFO,
                "initial weights (loss/accuracy): %s, %s",
                res[0],
                res[1],
            )
            history.add_loss_centralized(rnd=0, loss=res[0])
            history.add_accuracy_centralized(rnd=0, acc=res[1])

        # Run federated learning for num_rounds
        log(INFO, "[TIME] FL starting")
        start_time = timeit.default_timer()

        for current_round in range(1, num_rounds + 1):
            # Train model and replace previous global model
            weights_prime = self.fit_round(rnd=current_round)
            if weights_prime is not None:
                self.weights = weights_prime

            # Evaluate model using strategy implementation
            res_cen = self.strategy.evaluate(weights=self.weights)
            if res_cen is not None:
                loss_cen, acc_cen = res_cen
                log(
                    INFO,
                    "fit progress: (%s, %s, %s, %s)",
                    current_round,
                    loss_cen,
                    acc_cen,
                    timeit.default_timer() - start_time,
                )
                history.add_loss_centralized(rnd=current_round, loss=loss_cen)
                history.add_accuracy_centralized(rnd=current_round,
                                                 acc=acc_cen)

            # Evaluate model on a sample of available clients
            res_fed = self.evaluate(rnd=current_round)
            if res_fed is not None and res_fed[0] is not None:
                loss_fed, _ = res_fed
                history.add_loss_distributed(rnd=current_round,
                                             loss=cast(float, loss_fed))

            # Conclude round
            loss = res_cen[0] if res_cen is not None else None
            acc = res_cen[1] if res_cen is not None else None
            should_continue = self.strategy.on_conclude_round(
                current_round, loss, acc)
            if not should_continue:
                break

        end_time = timeit.default_timer()
        elapsed = end_time - start_time
        log(INFO, "[TIME] FL finished in %s", elapsed)
        return history
Example No. 20
    def on_configure_fit(
        self, rnd: int, weights: Weights, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""

        # Block until `min_num_clients` are available
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        success = client_manager.wait_for(
            num_clients=min_num_clients, timeout=WAIT_TIMEOUT
        )
        if not success:
            # Do not continue if not enough clients are available
            log(
                INFO,
                "FedFS: not enough clients available after timeout %s",
                WAIT_TIMEOUT,
            )
            return []

        # Sample clients
        if rnd == 1:
            # Sample with 1/k in the first round
            log(
                DEBUG,
                "FedFS round %s, sample %s clients with 1/k",
                str(rnd),
                str(sample_size),
            )
            clients = self._one_over_k_sampling(
                sample_size=sample_size, client_manager=client_manager
            )
        else:
            fast_round = is_fast_round(rnd - 1, r_fast=self.r_fast, r_slow=self.r_slow)
            log(
                DEBUG,
                "FedFS round %s, sample %s clients, fast_round %s",
                str(rnd),
                str(sample_size),
                str(fast_round),
            )
            clients = self._fs_based_sampling(
                sample_size=sample_size,
                client_manager=client_manager,
                fast_round=fast_round,
            )

        # Prepare parameters and config
        parameters = weights_to_parameters(weights)
        config = {}
        if self.on_fit_config_fn is not None:
            # Use custom fit config function if provided
            config = self.on_fit_config_fn(rnd)

        # Set timeout for this round
        if self.durations:
            candidates = timeout_candidates(
                durations=self.durations, max_timeout=self.t_max,
            )
            timeout = next_timeout(
                candidates=candidates, percentile=self.dynamic_timeout_percentile,
            )
            config["timeout"] = str(timeout)
        else:
            # The initial round has no past durations, so use the maximum timeout
            config["timeout"] = str(self.t_max)

        # Fit instructions
        fit_ins = (parameters, config)

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]
Example No. 21
def on_channel_state_change(channel_connectivity: str) -> None:
    """Log channel connectivity."""
    log(DEBUG, channel_connectivity)
Example No. 22
def main() -> None:
    """Download data."""
    log(INFO, "Download Fashion-MNIST")
    tf.keras.datasets.fashion_mnist.load_data()
Example No. 23
    def create_instances(
        self,
        num_cpu: int,
        num_ram: float,
        timeout: int,
        num_instance: int = 1,
        gpu: bool = False,
    ) -> List[AdapterInstance]:
        """Create one or more EC2 instance(s) of the same type.

        Args:
            num_cpu (int): Number of instance vCPU (values in
                            ec2_adapter.INSTANCE_TYPES_CPU or INSTANCE_TYPES_GPU)
            num_ram (float): RAM in GB (values in ec2_adapter.INSTANCE_TYPES_CPU
                            or INSTANCE_TYPES_GPU)
            timeout (int): Timeout in minutes
            num_instance (int): Number of instances to start if currently available in EC2

        """
        # The instance will be set to terminate after shutdown. This is a
        # fail-safe in case something goes wrong and the instances are not
        # shut down correctly.
        user_data = ["#!/bin/bash", f"sudo shutdown -P {timeout}"]
        user_data_str = "\n".join(user_data)

        instance_type, hourly_price = find_instance_type(
            num_cpu, num_ram,
            INSTANCE_TYPES_GPU if gpu else INSTANCE_TYPES_CPU)

        hourly_price_total = round(num_instance * hourly_price, 2)

        log(
            INFO,
            "Starting %s instances of type %s which in total will roughly cost $%s an hour.",
            num_instance,
            instance_type,
            hourly_price_total,
        )

        result: EC2RunInstancesResult = self.ec2.run_instances(
            ImageId=self.image_id,
            # We always want an exact number of instances
            MinCount=num_instance,
            MaxCount=num_instance,
            InstanceType=instance_type,
            KeyName=self.key_name,
            IamInstanceProfile={"Name": "FlowerInstanceProfile"},
            SubnetId=self.subnet_id,
            SecurityGroupIds=self.security_group_ids,
            TagSpecifications=self.tag_specifications,
            InstanceInitiatedShutdownBehavior="terminate",
            UserData=user_data_str,
        )

        instance_ids = [ins["InstanceId"] for ins in result["Instances"]]

        # Even once all instances are in the "running" state, we still have to check
        # the InstanceStatus, which reports impaired functionality stemming from
        # issues internal to the instance, such as impaired reachability
        try:
            self._wait_until_instances_are_reachable(instance_ids=instance_ids)
        except EC2StatusTimeout:
            self.terminate_instances(instance_ids)
            raise EC2CreateInstanceFailure()

        return self.list_instances(instance_ids=instance_ids)
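A hypothetical call using the parameters documented above; `adapter` stands for an instance of the adapter class this method belongs to (example values only):

# Request two CPU instances with 4 vCPUs and 16 GB RAM each,
# auto-terminating 60 minutes after startup.
instances = adapter.create_instances(num_cpu=4, num_ram=16, timeout=60, num_instance=2)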
Example No. 24
def start_server(
    server_address: str = DEFAULT_SERVER_ADDRESS,
    server: Optional[Server] = None,
    config: Optional[Dict[str, int]] = None,
    strategy: Optional[Strategy] = None,
) -> None:
    """Start a Flower server using the gRPC transport layer."""

    # Create server instance if none was given
    if server is None:
        client_manager = SimpleClientManager()
        if strategy is None:
            strategy = FedAvg()
        server = Server(client_manager=client_manager, strategy=strategy)

    # Set default config values
    if config is None:
        config = {}
    if "num_rounds" not in config:
        config["num_rounds"] = 1

    # Start gRPC server
    grpc_server = start_insecure_grpc_server(
        client_manager=server.client_manager(), server_address=server_address)
    log(INFO, "Flower server running (insecure, %s rounds)",
        config["num_rounds"])

    # Fit model
    hist = server.fit(num_rounds=config["num_rounds"])
    log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
    log(INFO, "app_fit: accuracies_distributed %s",
        str(hist.accuracies_distributed))
    log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
    log(INFO, "app_fit: accuracies_centralized %s",
        str(hist.accuracies_centralized))

    # Temporary workaround to force distributed evaluation
    server.strategy.eval_fn = None  # type: ignore

    # Evaluate the final trained model
    res = server.evaluate(rnd=-1)
    if res is not None:
        loss, (results, failures) = res
        log(INFO, "app_evaluate: federated loss: %s", str(loss))
        log(
            INFO,
            "app_evaluate: results %s",
            str([(res[0].cid, res[1]) for res in results]),
        )
        log(INFO, "app_evaluate: failures %s", str(failures))
    else:
        log(INFO, "app_evaluate: no evaluation result")

    # Stop the gRPC server
    grpc_server.stop(1)
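A minimal hypothetical invocation of `start_server`, relying on the FedAvg default created inside the function; the address is an example value, not necessarily DEFAULT_SERVER_ADDRESS:

# Start a Flower server with the default FedAvg strategy for three rounds
start_server(server_address="[::]:8080", config={"num_rounds": 3})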
Example No. 25
def main() -> None:
    """Start server and train a number of rounds."""
    args = parse_args()

    # Configure logger
    configure(identifier="server", host=args.log_host)

    server_setting = get_setting(args.setting).server
    log(INFO, "server_setting: %s", server_setting)

    # Load evaluation data
    (_, _), (x_test, y_test) = tf_fashion_mnist_partitioned.load_data(
        iid_fraction=0.0, num_partitions=1
    )
    if server_setting.dry_run:
        x_test = x_test[0:50]
        y_test = y_test[0:50]

    # Load model (for centralized evaluation)
    model = orig_cnn(input_shape=(28, 28, 1), seed=SEED)

    # Create client_manager
    client_manager = fl.SimpleClientManager()

    # Strategy
    eval_fn = get_eval_fn(model=model, num_classes=10, xy_test=(x_test, y_test))
    on_fit_config_fn = get_on_fit_config_fn(
        lr_initial=server_setting.lr_initial,
        timeout=server_setting.training_round_timeout,
        partial_updates=server_setting.partial_updates,
    )

    if server_setting.strategy == "fedavg":
        strategy = fl.strategy.FedAvg(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )

    if server_setting.strategy == "fast-and-slow":
        if server_setting.training_round_timeout is None:
            raise ValueError(
                "No `training_round_timeout` set for `fast-and-slow` strategy"
            )
        t_fast = (
            math.ceil(0.5 * server_setting.training_round_timeout)
            if server_setting.training_round_timeout_short is None
            else server_setting.training_round_timeout_short
        )
        strategy = fl.strategy.FastAndSlow(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            importance_sampling=server_setting.importance_sampling,
            dynamic_timeout=server_setting.dynamic_timeout,
            dynamic_timeout_percentile=0.8,
            alternating_timeout=server_setting.alternating_timeout,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )

    if server_setting.strategy == "fedfs-v0":
        if server_setting.training_round_timeout is None:
            raise ValueError("No `training_round_timeout` set for `fedfs-v0` strategy")
        t_fast = (
            math.ceil(0.5 * server_setting.training_round_timeout)
            if server_setting.training_round_timeout_short is None
            else server_setting.training_round_timeout_short
        )
        strategy = fl.strategy.FedFSv0(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            r_fast=1,
            r_slow=1,
            t_fast=t_fast,
            t_slow=server_setting.training_round_timeout,
        )

    if server_setting.strategy == "fedfs-v1":
        if server_setting.training_round_timeout is None:
            raise ValueError("No `training_round_timeout` set for `fedfs-v1` strategy")
        strategy = fl.strategy.FedFSv1(
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            dynamic_timeout_percentile=0.8,
            r_fast=1,
            r_slow=1,
            t_max=server_setting.training_round_timeout,
            use_past_contributions=True,
        )

    if server_setting.strategy == "qffedavg":
        strategy = fl.strategy.QffedAvg(
            q_param=0.2,
            qffl_learning_rate=0.1,
            fraction_fit=server_setting.sample_fraction,
            min_fit_clients=server_setting.min_sample_size,
            min_available_clients=server_setting.min_num_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
        )

    # Run server
    log(INFO, "Instantiating server, strategy: %s", str(strategy))
    server = fl.Server(client_manager=client_manager, strategy=strategy)
    fl.app.server.start_server(
        DEFAULT_SERVER_ADDRESS, server, config={"num_rounds": server_setting.rounds},
    )
Example No. 26
def run(baseline: str, setting: str, adapter: str) -> None:
    """Run baseline."""
    print(f"Starting baseline with {setting} settings.")

    wheel_remote_path = (f"/root/{WHEEL_FILENAME}" if adapter == "docker" else
                         f"/home/ubuntu/{WHEEL_FILENAME}")

    if baseline == "tf_cifar":
        settings = tf_cifar_settings.get_setting(setting)
    elif baseline == "tf_fashion_mnist":
        settings = tf_fashion_mnist_settings.get_setting(setting)
    elif baseline == "tf_hotkey":
        settings = tf_hotkey_settings.get_setting(setting)
    else:
        raise Exception("Baseline not found.")

    # Get instances and add a logserver to the list
    instances = settings.instances
    instances.append(
        Instance(name="logserver", group="logserver", num_cpu=2, num_ram=2))

    # Configure cluster
    log(INFO, "(1/9) Configure cluster.")
    cluster = configure_cluster(adapter, instances, baseline, setting)

    # Start the cluster; this takes some time
    log(INFO, "(2/9) Start cluster.")
    cluster.start()

    # Upload wheel to all instances
    log(INFO, "(3/9) Upload wheel to all instances.")
    cluster.upload_all(WHEEL_LOCAL_PATH, wheel_remote_path)

    # Install the wheel on all instances
    log(INFO, "(4/9) Install wheel on all instances.")
    cluster.exec_all(command.install_wheel(wheel_remote_path))

    # Download datasets in server and clients
    log(INFO, "(5/9) Download dataset on server and clients.")
    cluster.exec_all(command.download_dataset(baseline=baseline),
                     groups=["server", "clients"])

    # Start logserver
    log(INFO, "(6/9) Start logserver.")
    logserver = cluster.get_instance("logserver")
    cluster.exec(
        logserver.name,
        command.start_logserver(
            logserver_s3_bucket=CONFIG.get("aws", "logserver_s3_bucket"),
            logserver_s3_key=f"{baseline}_{setting}_{now()}.log",
        ),
    )

    # Start Flower server on Flower server instances
    log(INFO, "(7/9) Start server.")
    cluster.exec(
        "server",
        command.start_server(
            log_host=f"{logserver.private_ip}:8081",
            baseline=baseline,
            setting=setting,
        ),
    )

    # Start Flower clients
    log(INFO, "(8/9) Start clients.")
    server = cluster.get_instance("server")

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        # Start one Flower client per client setting and wait for all of them
        concurrent.futures.wait([
            executor.submit(
                cluster.exec,
                client_setting.instance_name,
                command.start_client(
                    log_host=f"{logserver.private_ip}:8081",
                    server_address=f"{server.private_ip}:8080",
                    baseline=baseline,
                    setting=setting,
                    cid=client_setting.cid,
                ),
            ) for client_setting in settings.clients
        ])

    # Shut down server and client instances after 10 min if no Flower process
    # is running on them
    log(INFO, "(9/9) Start shutdown watcher script.")
    cluster.exec_all(command.watch_and_shutdown("flower", adapter))

    # Give user info how to tail logfile
    private_key = (DOCKER_PRIVATE_KEY if adapter == "docker" else
                   path.expanduser(CONFIG.get("ssh", "private_key")))

    log(
        INFO,
        "If you would like to tail the central logfile run:\n\n\t%s\n",
        command.tail_logfile(adapter, private_key, logserver),
    )