Beispiel #1
0
 def select_delta_params_and_weight(element):
     """Returns one client's parameter delta and its domain-scaled weight.

     Args:
       element: A pair ``(idx, client_state)``. ``idx`` indexes into
         ``domain_ids`` to find the client's domain; ``client_state`` provides
         ``params`` and ``num_examples``.

     Returns:
       A tuple ``(delta_params, client_weight)``:
         * ``delta_params`` is ``state.params - client_state.params``,
           computed leaf-wise with ``core.tree_multimap``.
         * ``client_weight`` is ``scaled_dw[domain_id] * num_examples``.

     Note: ``core``, ``state``, ``domain_ids`` and ``scaled_dw`` are free names
     resolved from the enclosing scope, which is not shown here.
     """
     position, local_state = element

     def _leaf_difference(server_leaf, client_leaf):
         return server_leaf - client_leaf

     # Server params minus this client's params, leaf by leaf.
     delta = core.tree_multimap(_leaf_difference, state.params,
                                local_state.params)
     # The client's example count, scaled by its domain's weight, is used as
     # this client's weight in the weighted average.
     domain = domain_ids[position]
     weight = scaled_dw[domain] * local_state.num_examples
     return delta, weight
Beispiel #2
0
    def run_round(self, state: mime_lite.MimeLiteState,
                  client_ids: List[str]) -> mime_lite.MimeLiteState:
        """Runs one round of Mime.

        The round proceeds in this order:
          1. Compute a gradient at the current server params over the combined
             training data of ``client_ids``. This gradient is both the
             clients' control variate and the server optimizer's input.
          2. Train the clients with ``self._client_trainer``, initialized from
             the server params, the server opt_state and that gradient.
          3. Take the weighted average (via ``core.tree_mean``) of each client's
             parameter delta (server params minus client params), using the
             client's ``num_examples`` as its weight.
          4. Subtract ``server_learning_rate`` times that average from the
             server params.
          5. Pass the step-1 gradient to ``self._base_optimizer.update_fn``.
             Only the new opt_state it returns is kept; the update is discarded.

        Args:
          state: Current server state (``params`` and ``opt_state``).
          client_ids: Ids of the clients participating in this round.

        Returns:
          A new ``mime_lite.MimeLiteState`` holding the updated params and
          opt_state.
        """
        # Step 1: full-batch gradient at the server params on the train data.
        raw_dataset = core.create_tf_dataset_for_clients(self.federated_data,
                                                         client_ids)
        combined_dataset = core.preprocess_tf_dataset(
            raw_dataset, self._hparams.combined_data_hparams)
        full_batch_grads = mime_lite.compute_gradient(
            stream=combined_dataset.as_numpy_iterator(),
            params=state.params,
            model=self._model,
            rng_seq=self._rng_seq,
        )

        # Step 2: local training with the server gradient as control variate.
        trainer = self._client_trainer
        initial_trainer_state = trainer.init_state(
            params=state.params,
            opt_state=state.opt_state,
            control_variate=full_batch_grads)
        trained_clients = core.train_multiple_clients(
            federated_data=self.federated_data,
            client_ids=client_ids,
            client_trainer=trainer,
            init_client_trainer_state=initial_trainer_state,
            rng_seq=self._rng_seq,
            client_data_hparams=self._hparams.train_data_hparams)

        # Step 3: example-weighted mean of the (server - client) param deltas.
        def subtract(server_leaf, client_leaf):
            return server_leaf - client_leaf

        deltas_with_weights = (
            (core.tree_multimap(subtract, state.params, trained.params),
             trained.num_examples)
            for trained in trained_clients)
        mean_delta = core.tree_mean(deltas_with_weights)

        # Step 4: move the server params against the mean delta, scaled by the
        # server learning rate.
        def apply_delta(param, delta):
            return param - self._hparams.server_learning_rate * delta

        new_params = core.tree_multimap(apply_delta, state.params, mean_delta)

        # Step 5: advance the server opt_state using the base optimizer and the
        # server gradient; the returned update itself is discarded.
        _, new_opt_state = self._base_optimizer.update_fn(full_batch_grads,
                                                          state.opt_state)
        return mime_lite.MimeLiteState(new_params, new_opt_state)
Beispiel #3
0
 def select_delta_params_and_weight(client_state):
     """Returns a client's parameter delta paired with its example count.

     Args:
       client_state: Per-client result exposing ``params`` and
         ``num_examples``.

     Returns:
       ``(delta_params, num_examples)``, where ``delta_params`` is
       ``state.params - client_state.params`` computed leaf-wise with
       ``core.tree_multimap``.

     Note: ``core`` and ``state`` are free names resolved from the enclosing
     scope, which is not shown here.
     """
     def _leaf_difference(server_leaf, client_leaf):
         return server_leaf - client_leaf

     delta = core.tree_multimap(_leaf_difference, state.params,
                                client_state.params)
     weight = client_state.num_examples
     return delta, weight