Example #1
        def param_postprocess_function(delta_param, all_params, c):
            delta_ti = np_nest.apply_to_structure(lambda x: np.divide(x, len(c)), delta_param)

            ti_updates = []
            for client_index in c:
                new_client_params = np_nest.map_structure(np.add, all_params[client_index], delta_ti)
                precisions = new_client_params['w_pres']
                precisions[precisions < 0] = 1e-5
                new_client_params['w_pres'] = precisions
                ti_updates.append(np_nest.map_structure(np.subtract, new_client_params, all_params[client_index]))

            return ti_updates
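
For reference, the same post-processing written against plain dicts of NumPy arrays. This is a minimal sketch that assumes `np_nest.apply_to_structure` and `np_nest.map_structure` behave like elementwise operations over a flat dict of arrays; `min_pres` is a name introduced here purely for illustration.

import numpy as np

def postprocess_sketch(delta_param, all_params, c, min_pres=1e-5):
    # share the aggregated delta evenly over the len(c) selected clients
    delta_ti = {k: v / len(c) for k, v in delta_param.items()}

    ti_updates = []
    for client_index in c:
        new_params = {k: all_params[client_index][k] + delta_ti[k] for k in delta_ti}
        # keep the precisions positive: floor negative entries
        new_params['w_pres'] = np.where(new_params['w_pres'] < 0, min_pres,
                                        new_params['w_pres'])
        ti_updates.append({k: new_params[k] - all_params[client_index][k]
                           for k in new_params})
    return ti_updates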
Example #2
    def get_noised_result(self, sample_state, global_state):
        def add_noise(p):
            # noise matches the parameter's shape so the addition broadcasts exactly
            return p + np.random.normal(size=p.shape) * global_state.noise_stddev

        if self._ledger:
            self._ledger.record_sum_query(global_state.l2_norm_clip, global_state.noise_stddev)

        return np_nest.map_structure(add_noise, sample_state), global_state
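
A self-contained sketch of the same noising step, assuming the sample state is a flat dict of NumPy arrays and that the global state only needs to supply `noise_stddev`:

import numpy as np

def noised_sum_sketch(sample_state, noise_stddev, rng=None):
    # add i.i.d. Gaussian noise of the configured stddev to every array in the state
    rng = rng if rng is not None else np.random.default_rng()
    return {k: v + rng.normal(scale=noise_stddev, size=v.shape)
            for k, v in sample_state.items()}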
Example #3
    def tick(self):
        if self.should_stop():
            return False

        lambda_old = self.parameters

        logger.debug("Getting Client Updates")
        delta_is = []

        self.current_damping_factor = self.hyperparameters[
            "damping_factor"] * np.exp(
                -self.iterations * self.hyperparameters["damping_decay"])

        clients_can_update = [client.can_update() for client in self.clients]

        if not np.any(clients_can_update):
            return 0

        clients_updated = 0

        for i, client in enumerate(self.clients):
            if client.can_update():
                logger.debug(f'On client {i + 1} of {len(self.clients)}')
                client.set_hyperparameters(
                    {"damping_factor": self.current_damping_factor})
                delta_is.append(
                    client.get_update(model_parameters=lambda_old,
                                      model_hyperparameters=None,
                                      update_ti=True))
                logger.debug(
                    f'Finished Client {i + 1} of {len(self.clients)}\n\n')
                clients_updated += 1

        logger.debug("Received client updates")
        lambda_new = lambda_old
        for delta_i in delta_is:
            lambda_new = np_nest.map_structure(np.add, lambda_new, delta_i)

        self.parameters = lambda_new

        # update the model parameters
        self.model.set_parameters(self.parameters)
        logger.debug(
            f"Iteration {self.iterations} complete.\nNew Parameters:\n {pretty_dump.dump(lambda_new)}\n"
        )
        for client in self.clients:
            client.set_metadata({"global_iteration": self.iterations})

        self.log_update()

        self.iterations += 1

        return clients_updated
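
The damping schedule at the top of `tick` is a plain exponential decay in the server iteration count, damping_factor * exp(-iterations * damping_decay). A quick check of the values it produces (the numbers below are illustrative, not taken from the source):

import numpy as np

damping_factor, damping_decay = 0.5, 0.1   # illustrative values only
for iteration in range(5):
    print(iteration, damping_factor * np.exp(-iteration * damping_decay))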
Example #4
    def preprocess_record(self, params, record):
        """
        Return the scaled record and also the l2 norm (to deduce whether clipping occured or not)

        :param params:
        :param record:
        :return:
        """
        l2_norm_clip = params
        logger.debug(f"Using {l2_norm_clip}")
        l2_norm = np.sqrt(np_nest.reduce_structure(lambda p: np.linalg.norm(p) ** 2,
                                                   np.add,
                                                   record))
        self._record_l2_norm = l2_norm
        if l2_norm < l2_norm_clip:
            return record
        else:
            return np_nest.map_structure(lambda p: np.divide(p, np.abs(l2_norm / l2_norm_clip)), record)
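
This is the standard per-record L2 clip used in differentially private sum queries: a record inside the ball of radius `l2_norm_clip` passes through unchanged, and a record outside it is scaled back onto the ball's surface, so the output norm is always min(||record||_2, l2_norm_clip). A standalone sketch over a flat dict of arrays:

import numpy as np

def clip_record_sketch(record, l2_norm_clip):
    # global l2 norm over every array in the record
    l2_norm = np.sqrt(sum(np.linalg.norm(v) ** 2 for v in record.values()))
    if l2_norm < l2_norm_clip:
        return record
    # scale the whole record so its norm equals the clipping bound
    scale = l2_norm_clip / l2_norm
    return {k: v * scale for k, v in record.items()}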
Example #5
    def compute_update(self,
                       model_parameters=None,
                       model_hyperparameters=None,
                       update_ti=True):
        logger.debug("Computing Update")
        if not self.can_update():
            logger.warning(
                'Incorrectly tried to update a client that cannot be updated!')
            return np_nest.map_structure(np.zeros_like,
                                         self.model.get_parameters())
        delta_lambda_i_tilde = super().compute_update(model_parameters,
                                                      model_hyperparameters,
                                                      update_ti)

        logger.debug("Computing Privacy Cost")
        formatted_ledger = self.dp_query.ledger.get_formatted_ledger()
        for _, accountant in self.accountants.items():
            accountant.update_privacy(formatted_ledger)

        return delta_lambda_i_tilde
Example #6
    def log_update(self):
        super().log_update()

        if 'global_iteration' in list(self.metadata.keys()):
            self.log['global_iteration'].append(
                self.metadata['global_iteration'])

        self.log['times_updated'].append(self.times_updated)

        if self.metadata['log_params']:
            self.log['params'].append(
                np_nest.structured_ndarrays_to_lists(
                    self.model.get_parameters()))
        if self.metadata['log_t_i']:
            self.log['t_i'].append(
                np_nest.structured_ndarrays_to_lists(
                    np_nest.map_structure(np.mean, self.t_i)))
        if self.metadata['log_model_info']:
            self.log['model'].append(
                np_nest.structured_ndarrays_to_lists(
                    self.model.get_incremental_log_record()))
Example #7
    def compute_update(self,
                       model_parameters=None,
                       model_hyperparameters=None,
                       update_ti=True):
        super().compute_update(model_parameters, model_hyperparameters)

        parameters_old = self.model.get_parameters()
        t_i = np_nest.map_structure(np.subtract, parameters_old,
                                    self.hyperparameters['prior'])

        self.t_i = t_i

        parameters_new = self.model.fit(self.data, t_i)

        delta_lambda_i = np_utils.subtract_params(parameters_new,
                                                  parameters_old)

        logger.debug(f"Old Params: {parameters_old}\n"
                     f"New Params: {parameters_new}\n")

        self.times_updated += 1

        return delta_lambda_i
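
The bookkeeping in `compute_update` is pure elementwise arithmetic on the (natural) parameter structures: `t_i` is the current parameters minus the prior, and the returned delta is the freshly fitted parameters minus the old ones. A toy check with flat dicts (the numbers are made up):

import numpy as np

lam_old = {'w_nat_mean': np.array([1.0, 2.0]), 'w_pres': np.array([3.0, 4.0])}
prior   = {'w_nat_mean': np.zeros(2),          'w_pres': np.ones(2)}
lam_new = {'w_nat_mean': np.array([1.5, 2.5]), 'w_pres': np.array([3.5, 4.5])}

# t_i as computed above: current parameters minus the prior
t_i = {k: lam_old[k] - prior[k] for k in lam_old}
# the update returned to the server: new parameters minus old
delta_lambda_i = {k: lam_new[k] - lam_old[k] for k in lam_old}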
Example #8
    def merge_sample_states(self, sample_state_1, sample_state_2):
        return np_nest.map_structure(np.add, sample_state_1, sample_state_2)
Example #9
    def accumulate_preprocessed_record(self, sample_state, record):
        return np_nest.map_structure(np.add, sample_state, record)
Example #10
    def initial_sample_state(self, param_groups):
        """ Return a state of zeros with the same shape as the parameter groups. """
        return np_nest.map_structure(np.zeros_like, param_groups)
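
Examples #8, #9 and #10 are the three pieces of a simple sum query over nested NumPy structures: a zero state shaped like the parameters, elementwise accumulation of records, and elementwise merging of partial states. A minimal sketch of how they compose, assuming flat dicts of arrays:

import numpy as np

def initial_sample_state(param_groups):
    return {k: np.zeros_like(v) for k, v in param_groups.items()}

def accumulate_preprocessed_record(sample_state, record):
    return {k: sample_state[k] + record[k] for k in sample_state}

def merge_sample_states(s1, s2):
    return {k: s1[k] + s2[k] for k in s1}

params  = {'w_nat_mean': np.zeros(3), 'w_pres': np.ones(3)}
records = [{'w_nat_mean': np.full(3, float(i)), 'w_pres': np.full(3, 2.0 * i)}
           for i in range(4)]

partial_a = initial_sample_state(params)
for r in records[:2]:
    partial_a = accumulate_preprocessed_record(partial_a, r)

partial_b = initial_sample_state(params)
for r in records[2:]:
    partial_b = accumulate_preprocessed_record(partial_b, r)

total = merge_sample_states(partial_a, partial_b)   # elementwise sum of all four records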
Example #11
def run_experiment(ray_cfg, prior_pres, PVI_settings, privacy_settings,
                   optimisation_settings, N_samples, N_iterations, prediction,
                   experiment_tag, logging_base_directory, save_t_is, _run,
                   _config, seed):
    torch.set_num_threads(int(ray_cfg["num_cpus"]))
    np.random.seed(seed)
    torch.manual_seed(seed)

    try:

        training_set, test_set, d_in = load_data()
        clients_data, nis, prop_positive, M = generate_dataset_distribution_func()(
            training_set["x"], training_set["y"])

        _run.info = {
            **_run.info,
            "prop_positive": prop_positive,
            "n_is": nis,
        }

        if ray_cfg["redis_address"] is None:
            logger.info("Running Locally")
            ray.init(num_cpus=ray_cfg["num_cpus"],
                     num_gpus=ray_cfg["num_gpus"],
                     logging_level=logging.INFO,
                     local_mode=True)
        else:
            logger.info("Connecting to existing server")
            ray.init(redis_address=ray_cfg["redis_address"],
                     logging_level=logging.INFO)

        prior_params = {
            "w_nat_mean": np.zeros(d_in, dtype=np.float32),
            "w_pres": prior_pres * np.ones(d_in, dtype=np.float32)
        }

        logger.debug(
            f"Prior Parameters:\n\n{pretty_dump.dump(prior_params)}\n")

        def param_postprocess_function(delta_param, all_params, c):
            delta_ti = np_nest.apply_to_structure(
                lambda x: np.divide(x, len(c)), delta_param)

            ti_updates = []
            for client_index in c:
                new_client_params = np_nest.map_structure(
                    np.add, all_params[client_index], delta_ti)
                precisions = new_client_params['w_pres']
                precisions[precisions < 0] = 1e-5
                new_client_params['w_pres'] = precisions
                ti_update = np_nest.map_structure(np.subtract,
                                                  new_client_params,
                                                  all_params[client_index])
                ti_updates.append(ti_update)
                # logger.debug('*** CLIENT ***')
                # logger.debug(new_client_params)
                # logger.debug(ti_update)
                # logger.debug(all_params[client_index])
            return ti_updates

        param_postprocess_handle = param_postprocess_function

        ti_init = np_nest.map_structure(np.zeros_like, prior_params)
        # client factories for each client - this avoids pickling of the client object for ray internals
        client_factories = [
            StandardClient.create_factory(
                model_class=MeanFieldMultiDimensionalLogisticRegression,
                data=clients_data[i],
                model_parameters=ti_init,
                model_hyperparameters={
                    "base_optimizer_class": torch.optim.Adagrad,
                    "wrapped_optimizer_class": StandardOptimizer,
                    "base_optimizer_parameters": {
                        "lr": optimisation_settings["lr"],
                        "lr_decay": optimisation_settings["lr_decay"]
                    },
                    "wrapped_optimizer_parameters": {},
                    "N_steps": optimisation_settings["N_steps"],
                    "N_samples": N_samples,
                    "n_in": d_in,
                    "batch_size": optimisation_settings["L"],
                    "reset_optimiser": True,
                },
                hyperparameters={
                    "t_i_init_function": lambda x: np.zeros(x.shape),
                    "t_i_postprocess_function": postprocess_MF_logistic_ti,
                },
                metadata={
                    'client_index': i,
                    'test_self': {
                        'accuracy': compute_prediction_accuracy,
                        'log_lik': compute_log_likelihood
                    }
                }) for i in range(M)
        ]

        logger.info(f"Making M={M} Clients")

        # custom decorator based on passed in resources!
        remote_decorator = ray.remote(num_cpus=int(ray_cfg["num_cpus"]),
                                      num_gpus=int(ray_cfg["num_gpus"]))

        server = remote_decorator(
            DPSequentialIndividualPVIParameterServer).remote(
                model_class=MeanFieldMultiDimensionalLogisticRegression,
                dp_query_class=NumpyGaussianDPQuery,
                model_parameters=prior_params,
                hyperparameters={
                    "L": privacy_settings["L"],
                    "dp_query_parameters": {
                        "l2_norm_clip":
                        privacy_settings["C"],
                        "noise_stddev":
                        privacy_settings["C"] *
                        privacy_settings["sigma_relative"]
                    },
                    "lambda_postprocess_func": param_postprocess_handle,
                    "damping_factor": PVI_settings["damping_factor"],
                    "damping_decay": PVI_settings["damping_decay"],
                },
                max_iterations=N_iterations * (M / privacy_settings["L"]),
                # ensure each client gets updated N_iterations times
                client_factories=client_factories,
                prior=prior_params,
                accounting_dict={
                    "MomentAccountant": {
                        "accountancy_update_method":
                        moment_accountant.compute_online_privacy_from_ledger,
                        "accountancy_parameters": {
                            "target_delta": privacy_settings["target_delta"]
                        }
                    }
                })

        while not ray.get(server.should_stop.remote()):
            # dispatch work to ray and grab the log
            st_tick = time.time()
            ray.get(server.tick.remote())
            num_iterations = ray.get(server.get_num_iterations.remote())

            st_log = time.time()
            sacred_log = {}
            sacred_log["server"], _ = ray.get(server.log_sacred.remote())
            params = ray.get(server.get_parameters.remote())
            client_sacred_logs = ray.get(
                server.get_client_sacred_logs.remote())
            for i, log in enumerate(client_sacred_logs):
                sacred_log["client_" + str(i)] = log[0]
            sacred_log = numpy_nest.flatten(sacred_log, sep=".")

            st_pred = time.time()
            # predict every interval, and also for the last "interval" runs.
            if ((num_iterations - 1) % prediction["interval"] == 0) or (
                    N_iterations - num_iterations < prediction["interval"]):
                # y_pred_train = ray.get(server.get_model_predictions.remote(training_set))
                y_pred_test = ray.get(
                    server.get_model_predictions.remote(test_set))
                # sacred_log["train_all"] = compute_log_likelihood(y_pred_train, training_set["y"])
                # sacred_log["train_accuracy"] = compute_prediction_accuracy(y_pred_train, training_set["y"])
                sacred_log["test_all"] = compute_log_likelihood(
                    y_pred_test, test_set["y"])
                test_acc = compute_prediction_accuracy(y_pred_test,
                                                       test_set["y"])
                sacred_log["test_accuracy"] = test_acc

                # logger.debug('server server')
                # logger.debug(f'    acc: {sacred_log["train_accuracy"]}')
                # logger.debug(f'    acc: {sacred_log["train_all"]}')
            end_pred = time.time()

            for k, v in sacred_log.items():
                _run.log_scalar(k, v, num_iterations)
            end = time.time()

            logger.info(
                f"Server Ticket Complete\n"
                f"Server Timings:\n"
                f"  Server Tick: {st_log - st_tick:.2f}s\n"
                f"  Predictions: {end_pred - st_pred:.2f}s\n"
                f"  Logging:     {end - end_pred + st_pred - st_log:.2f}s\n\n"
                f"Parameters:\n"
                f" {pretty_dump.dump(params)}\n"
                f"Iteration Number:{num_iterations}\n")

        final_log = ray.get(server.get_compiled_log.remote())
        final_log["N_i"] = nis
        final_log["Proportion_positive"] = prop_positive
        t = datetime.datetime.now()

        ex.add_artifact(
            save_log(final_log, "full_log",
                     ex.get_experiment_info()["name"], experiment_tag,
                     logging_base_directory, _run.info["test"], t),
            "full_log.json")

        if save_t_is:
            t_is = [
                client.t_i for client in ray.get(server.get_clients.remote())
            ]
            ex.add_artifact(
                save_pickle(t_is, 't_is',
                            ex.get_experiment_info()["name"], experiment_tag,
                            logging_base_directory, _run.info["test"], t),
                't_is.pkl')

    except pyarrow.lib.ArrowIOError:
        raise Exception("Experiment Terminated - was this you?")

    return test_acc
Example #12
    def tick(self):
        if self.should_stop():
            return False

        lambda_old = self.parameters
        L = self.hyperparameters["L"]
        M = len(self.clients)

        # sample a mini-batch of L of the M clients without replacement
        if L > M:
            raise ValueError(
                'Mini-batch size L cannot exceed the number of clients M')
        c = np.random.choice(M, L, replace=False)
        logger.info(f"Selected clients {c}")
        # delta_is = [client.compute_update.remote(lambda_old) for client in self.clients]

        delta_is = []
        client_params = []

        # decay the damping factor by one decay step per expected full pass over
        # the clients (L of the M clients are selected each iteration)
        self.current_damping_factor = self.hyperparameters[
            "damping_factor"] * np.exp(-self.iterations * L / M *
                                       self.hyperparameters["damping_decay"])
        for indx, client in enumerate(self.clients):
            logger.info(f'On client {indx + 1} of {len(self.clients)}')
            client.set_hyperparameters(
                {"damping_factor": self.current_damping_factor})
            client_params.append(client.t_i)
            if indx in c:
                # selected to be updated
                delta_is.append(
                    client.get_update(model_parameters=lambda_old,
                                      model_hyperparameters=None,
                                      update_ti=False))

        sample_state = self.dp_query_with_ledgers.initial_sample_state(
            delta_is[0])
        sample_params = self.dp_query_with_ledgers.derive_sample_params(
            self.query_global_state)
        self.query_global_state = self.dp_query_with_ledgers.initial_global_state(
        )

        derived_data = defaultdict(list)
        for indx, delta_i in enumerate(delta_is):
            sample_state = self.dp_query_with_ledgers.accumulate_record(
                sample_params, sample_state, delta_i)
            for k, v in self.dp_query_with_ledgers.get_record_derived_data(
            ).items():
                derived_data[k].append(v)

        delta_i_tilde, _ = self.dp_query_with_ledgers.get_noised_result(
            sample_state, self.query_global_state, c)
        client_updates = self.hyperparameters["lambda_postprocess_func"](
            delta_i_tilde, client_params, c)

        lambda_new = lambda_old
        for update in client_updates:
            lambda_new = np_nest.map_structure(np.add, lambda_new, update)

        self.parameters = lambda_new
        formatted_ledgers = self.dp_query_with_ledgers.get_formatted_ledgers()

        logger.debug(f"l2 clipping norms: {derived_data}")
        for k, v in derived_data.items():
            # summarise statistics instead
            derived_data[k] = np.percentile(np.array(v),
                                            [10.0, 30.0, 50.0, 70.0, 90.0])

        for client_index, client_update in zip(c, client_updates):
            self.clients[client_index].set_metadata(
                {"global_iteration": self.iterations})
            self.clients[client_index].update_ti(client_update)
            for k, v in self.accountants[client_index].items():
                v.update_privacy(formatted_ledgers[client_index])

        self.model.set_parameters(self.parameters)
        logger.debug(
            f"Iteration {self.iterations} complete.\nNew Parameters:\n {pretty_dump.dump(lambda_new)}\n"
        )

        self.log_update()
        self.log["derived_data"].append(
            structured_ndarrays_to_lists(derived_data))

        self.iterations += 1
Example #13
    def tick(self):
        if self.should_stop():
            return False

        lambda_old = self.parameters
        lambda_new = copy.deepcopy(lambda_old)

        logger.debug("Getting Client Updates")

        self.current_damping_factor = self.hyperparameters[
            "damping_factor"] * np.exp(
                -self.iterations * self.hyperparameters["damping_decay"])

        communications_this_round = 0

        for i in range(len(self.clients)):

            available_clients = [
                client.can_update() for client in self.clients
            ]

            if not np.any(available_clients):
                self.all_clients_stopped = True
                logger.info('All clients report to be finished. Stopping.')
                break

            client_index = int(
                np.random.choice(len(self.clients),
                                 1,
                                 replace=False,
                                 p=self.client_probs))
            logger.debug(f"Selected Client {client_index}")
            client = self.clients[client_index]

            if not client.can_update():
                logger.debug(
                    f"Skipping client {client_index}, client not available to update."
                )
                continue

            logger.debug(
                f'On client {i + 1} of {len(self.clients)}, client index {client_index}'
            )
            client.set_hyperparameters(
                {"damping_factor": self.current_damping_factor})
            delta_i = client.get_update(model_parameters=lambda_new,
                                        model_hyperparameters=None,
                                        update_ti=True)
            lambda_new = np_nest.map_structure(np.add, lambda_new, delta_i)
            logger.debug(f'Finished Client {i + 1} of {len(self.clients)}\n\n')

            communications_this_round += 1

        self.client_ti_norms = []
        for i in range(len(self.clients)):
            self.client_ti_norms.append(
                np.sqrt(
                    np_nest.reduce_structure(lambda p: np.linalg.norm(p)**2,
                                             np.add, self.clients[i].t_i)))

        self.parameters = lambda_new

        # update the model parameters
        self.model.set_parameters(self.parameters)
        logger.debug(
            f"Iteration {self.iterations} complete.\nNew Parameters:\n {pretty_dump.dump(lambda_new)}\n"
        )

        str = f"Iteration {self.iterations}\n\n"
        for k, v in lambda_new.items():
            str = str + f"mean_{k}: {np.sqrt(np.mean(v**2))}\n"

        logger.info(str)
        for client in self.clients:
            client.set_metadata({"global_iteration": self.iterations})

        self.log_update()

        self.iterations += 1

        return communications_this_round