示例#1
0
    def on_configure_evaluate(
        self, rnd: int, weights: Weights, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Build the evaluation instructions for round `rnd`.

        Returns an empty list when a centralized evaluation function
        (`self.eval_fn`) is set, so that no federated evaluation is
        configured. Otherwise every sampled client is paired with the same
        `(parameters, config)` instruction.
        """
        # A centralized evaluation function takes precedence over
        # client-side evaluation
        if self.eval_fn is not None:
            return []

        # Convert the weights; the config stays empty unless a custom
        # evaluation config function was supplied
        parameters = weights_to_parameters(weights)
        config_fn = self.on_evaluate_config_fn
        config = {} if config_fn is None else config_fn(rnd)
        evaluate_ins = (parameters, config)

        # Let the client manager pick the clients for this round
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        selected = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Every selected client receives the same instructions
        return [(proxy, evaluate_ins) for proxy in selected]
示例#2
0
 def _one_over_k_sampling(
         self, sample_size: int,
         client_manager: ClientManager) -> List[ClientProxy]:
     """Sample clients through the client manager (intended as 1/k sampling).

     The incoming `sample_size` is overwritten: both the sample size and
     the minimum number of clients are recomputed by `self.num_fit_clients`
     from the number of clients that are currently available.
     """
     num_available = client_manager.num_available()
     sample_size, min_num_clients = self.num_fit_clients(num_available)
     return client_manager.sample(
         num_clients=sample_size, min_num_clients=min_num_clients
     )
示例#3
0
    def _contribution_based_sampling(
            self, sample_size: int,
            client_manager: ClientManager) -> List[ClientProxy]:
        """Sample clients, weighting each one by its recorded contributions.

        Each client known to the client manager gets the raw weight
        `1.1 - penalty`. The penalty is 0.0 for clients without an entry in
        `self.contributions`; otherwise it is the mean of the per-entry
        ratios (second element divided by third element of every recorded
        tuple). The raw weights are handed to `normalize_and_sample`
        together with an index -> client ID mapping.
        """
        all_clients: Dict[str, ClientProxy] = client_manager.all()

        def raw_weight(cid: str) -> float:
            """Return the unnormalized sampling weight of client `cid`."""
            penalty = 0.0
            if cid in self.contributions:
                history: List[Tuple[int, int, int]] = self.contributions[cid]
                ratios = [numer / denom for _, numer, denom in history]
                penalty = statistics.mean(ratios)
            # Intended ordering of the resulting weight:
            # - highest for clients without any recorded contribution
            # - in between for clients with a mean ratio below 1
            # - lowest (but still positive for a mean ratio of 1) otherwise
            return 1.1 - penalty

        # Position in `raw` -> client ID, in the iteration order of `all_clients`
        cid_idx: Dict[int, str] = dict(enumerate(all_clients))
        raw: List[float] = [raw_weight(cid) for cid in all_clients]

        # Normalize the raw weights and draw the sample
        return normalize_and_sample(
            all_clients=all_clients,
            cid_idx=cid_idx,
            raw=np.array(raw),
            sample_size=sample_size,
            use_softmax=False,
        )
示例#4
0
    def configure_fit(
            self, rnd: int, parameters: Parameters,
            client_manager: ClientManager) -> List[Tuple[ClientProxy, FitIns]]:
        """Build the training instructions for round `rnd`.

        All sampled clients are paired with one shared `FitIns` holding the
        given `parameters` and the round config.
        """
        # The config stays empty unless a custom fit config function exists
        if self.on_fit_config_fn is None:
            config = {}
        else:
            config = self.on_fit_config_fn(rnd)
        fit_ins = FitIns(parameters, config)

        # Let the client manager pick the clients for this round
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available())
        selected = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients)

        # Every selected client receives the same instructions
        return [(proxy, fit_ins) for proxy in selected]
示例#5
0
    def configure_fit(
        self, rnd: int, weights: Weights, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Waits (up to `WAIT_TIMEOUT`) until the minimum number of clients is
        available, samples clients based on their past contributions, and
        sends every sampled client the same `FitIns`, whose config carries
        the round timeout under the `"timeout"` key.

        Args:
            rnd: Number of the round being configured.
            weights: Model weights, converted with `weights_to_parameters`.
            client_manager: Manager used to wait for and look up clients.

        Returns:
            One `(client, fit_ins)` pair per sampled client, or an empty
            list if not enough clients became available before the timeout.
        """

        # Block until `min_num_clients` are available
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        success = client_manager.wait_for(
            num_clients=min_num_clients, timeout=WAIT_TIMEOUT
        )
        if not success:
            # Do not continue if not enough clients are available
            log(
                INFO,
                "FedFS: not enough clients available after timeout %s",
                WAIT_TIMEOUT,
            )
            return []

        # Sample clients
        clients = self._contribution_based_sampling(
            sample_size=sample_size, client_manager=client_manager
        )

        # Prepare parameters and config
        parameters = weights_to_parameters(weights)
        config = {}
        if self.on_fit_config_fn is not None:
            # Use custom fit config function if provided. Work on a shallow
            # copy: the "timeout" entry added below must not leak into a
            # dict that the config function owns and may reuse.
            config = dict(self.on_fit_config_fn(rnd))

        # Set timeout for this round (fast or slow, decided by `is_fast_round`)
        use_fast_timeout = is_fast_round(rnd - 1, self.r_fast, self.r_slow)
        config["timeout"] = str(self.t_fast if use_fast_timeout else self.t_slow)

        # Fit instructions
        fit_ins = FitIns(parameters, config)

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]
示例#6
0
    def _fs_based_sampling(self, sample_size: int,
                           client_manager: ClientManager,
                           fast_round: bool) -> List[ClientProxy]:
        """Sample clients with 1/k * c/m in fast rounds and 1 - c/m in slow rounds.

        `k` is the number of clients known to the client manager and `c/m`
        is a client's contribution ratio. With `self.use_past_contributions`
        set, the ratio is the sum of all recorded `c` divided by the sum of
        all recorded `m`; otherwise only the most recent record is used.
        Clients without any record get `1/k` (fast round) or `1` (slow
        round). `E` is added to the importance of every client that has a
        record.
        """
        all_clients: Dict[str, ClientProxy] = client_manager.all()
        num_clients = len(all_clients)

        # Position in `raw` -> client ID, in the iteration order of `all_clients`
        cid_idx: Dict[int, str] = dict(enumerate(all_clients))

        raw: List[float] = []
        for cid in all_clients:
            if cid not in self.contributions:
                # Previously unselected clients
                raw.append(1 / num_clients if fast_round else 1)
                continue

            # Previously selected clients
            history: List[Tuple[int, int, int]] = self.contributions[cid]
            if self.use_past_contributions:
                total_c = sum(contrib for _, contrib, _ in history)
                total_m = sum(ceiling for _, _, ceiling in history)
                c_over_m = total_c / total_m
            else:
                _, last_c, last_m = history[-1]
                c_over_m = last_c / last_m

            if fast_round:
                raw.append((1 / num_clients) * c_over_m + E)
            else:
                raw.append(1 - c_over_m + E)

        log(
            DEBUG,
            "FedFS _fs_based_sampling, sample %s clients, raw %s",
            str(sample_size),
            str(raw),
        )

        # Normalize the raw importances and draw the sample
        return normalize_and_sample(
            all_clients=all_clients,
            cid_idx=cid_idx,
            raw=np.array(raw),
            sample_size=sample_size,
            use_softmax=False,
        )
示例#7
0
def register_client(
    client_manager: ClientManager,
    client: GrpcClientProxy,
    context: grpc.ServicerContext,
) -> bool:
    """Try registering GrpcClientProxy with ClientManager.

    On success, a termination callback is attached to the RPC so that the
    client's bridge is closed and the client is unregistered once the RPC
    ends. If gRPC reports that the callback could not be added (for example
    because the RPC has already terminated), the same cleanup runs
    immediately and the registration is reported as failed; otherwise
    nothing would ever unregister the client.

    Returns:
        The result of `client_manager.register(client)`, or `False` when
        the termination callback could not be attached.
    """
    is_success = client_manager.register(client)
    if not is_success:
        return is_success

    def rpc_termination_callback() -> None:
        client.bridge.close()
        client_manager.unregister(client)

    # `add_callback` returns False when the callback will never be invoked.
    # Compare against False explicitly so that context implementations
    # returning None keep the previous behavior.
    if context.add_callback(rpc_termination_callback) is False:
        rpc_termination_callback()
        return False

    return is_success
示例#8
0
    def configure_fit(
            self, rnd: int, weights: Weights,
            client_manager: ClientManager) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Waits (up to `WAIT_TIMEOUT`) until the minimum number of clients is
        available, samples the participants, and sends every sampled client
        the same `FitIns`, whose config carries the round timeout under the
        `"timeout"` key.

        Sampling mode:
            - `self.alternating_timeout`: contribution-based sampling.
            - `self.importance_sampling`: `_one_over_k_sampling` in round 1,
              `_fs_based_sampling` in every later round.
            - otherwise: `_one_over_k_sampling`.

        Timeout:
            - `self.dynamic_timeout`: derived from `self.durations` via
              `timeout_candidates` / `next_timeout`, or `self.t_slow` while
              no durations have been recorded.
            - `self.alternating_timeout`: `self.t_fast` or `self.t_slow`,
              decided by `is_fast_round`.
            - otherwise: `self.t_slow`.

        Args:
            rnd: Number of the round being configured.
            weights: Model weights, converted with `weights_to_parameters`.
            client_manager: Manager used to wait for and sample clients.

        Returns:
            One `(client, fit_ins)` pair per sampled client, or an empty
            list if not enough clients became available before the timeout.
        """

        # Block until `min_num_clients` are available
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available())
        success = client_manager.wait_for(num_clients=min_num_clients,
                                          timeout=WAIT_TIMEOUT)
        if not success:
            # Do not continue if not enough clients are available
            log(
                INFO,
                "FedFS: not enough clients available after timeout %s",
                WAIT_TIMEOUT,
            )
            return []

        # Sample clients
        if self.alternating_timeout:
            log(
                DEBUG,
                "FedFS round %s, sample %s clients (based on all previous contributions)",
                str(rnd),
                str(sample_size),
            )
            clients = self._contribution_based_sampling(
                sample_size=sample_size, client_manager=client_manager)
        elif self.importance_sampling:
            if rnd == 1:
                # Sample with 1/k in the first round
                log(
                    DEBUG,
                    "FedFS round %s, sample %s clients with 1/k",
                    str(rnd),
                    str(sample_size),
                )
                clients = self._one_over_k_sampling(
                    sample_size=sample_size, client_manager=client_manager)
            else:
                fast_round = is_fast_round(rnd - 1,
                                           r_fast=self.r_fast,
                                           r_slow=self.r_slow)
                log(
                    DEBUG,
                    "FedFS round %s, sample %s clients, fast_round %s",
                    str(rnd),
                    str(sample_size),
                    str(fast_round),
                )
                clients = self._fs_based_sampling(
                    sample_size=sample_size,
                    client_manager=client_manager,
                    fast_round=fast_round,
                )
        else:
            clients = self._one_over_k_sampling(sample_size=sample_size,
                                                client_manager=client_manager)

        # Prepare parameters and config
        parameters = weights_to_parameters(weights)
        config = {}
        if self.on_fit_config_fn is not None:
            # Use custom fit config function if provided. Work on a shallow
            # copy: the "timeout" entry added below must not leak into a
            # dict that the config function owns and may reuse.
            config = dict(self.on_fit_config_fn(rnd))

        # Set timeout for this round
        if self.dynamic_timeout:
            if self.durations:
                candidates = timeout_candidates(
                    durations=self.durations,
                    max_timeout=self.t_slow,
                )
                timeout = next_timeout(
                    candidates=candidates,
                    percentile=self.dynamic_timeout_percentile,
                )
                config["timeout"] = str(timeout)
            else:
                # Initial round has no past durations, use max_timeout
                config["timeout"] = str(self.t_slow)
        elif self.alternating_timeout:
            use_fast_timeout = is_fast_round(rnd - 1, self.r_fast, self.r_slow)
            config["timeout"] = str(
                self.t_fast if use_fast_timeout else self.t_slow)
        else:
            config["timeout"] = str(self.t_slow)

        # Fit instructions
        fit_ins = FitIns(parameters, config)

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]