def get_noised_result(self, sample_state, global_state):
    """Delegate noising to the wrapped query, then close the current ledger sample.

    :param sample_state: sample state forwarded unchanged to ``self._query``.
    :param global_state: global state forwarded unchanged to ``self._query``.
    :return: ``(result, new_global_state)`` as produced by the wrapped query,
        except that every leaf of ``result`` is replaced by a detached clone.
    """
    noised = self._query.get_noised_result(sample_state, global_state)
    result, new_global_state = noised

    # Close out the ledger sample once the inner query has produced its result.
    self._ledger.finalise_sample()

    def _detached_copy(tensor):
        # Fresh storage, cut off from any autograd history.
        return tensor.clone().detach()

    return nest.map_structure(_detached_copy, result), new_global_state
def get_noised_result(self, sample_state, global_state):
    """Perturb every tensor in ``sample_state`` with Gaussian noise.

    If a ledger is attached (and truthy), the sum query is recorded with the
    clipping bound and noise standard deviation taken from ``global_state``
    before any noise is drawn.

    :param sample_state: nested structure of tensors to noise.
    :param global_state: state exposing ``l2_norm_clip`` and ``noise_stddev``.
    :return: ``(noised_structure, global_state)``; the global state is
        returned unchanged.
    """

    def _perturb(value):
        noise = torch.randn_like(value) * global_state.noise_stddev
        return value + noise

    ledger = self._ledger
    if ledger:
        ledger.record_sum_query(global_state.l2_norm_clip, global_state.noise_stddev)

    noised = nest.map_structure(_perturb, sample_state)
    return noised, global_state
def get_noised_result(self, sample_state, global_state, selected_indices):
    """Record the sum query on each selected client's ledger and return the noised result.

    :param sample_state: sample state forwarded to the wrapped ``self._query``.
    :param global_state: global state forwarded to the wrapped query; its
        ``l2_norm_clip`` and ``noise_stddev`` are what get recorded.
    :param selected_indices: container of the client indices selected this
        round. Only indices in ``range(self.M)`` are recorded; membership is
        tested with ``in``.
    :return: ``(result, new_global_state)`` from the wrapped query. Every tensor
        leaf of ``result`` is replaced by a detached clone; non-tensor leaves
        are passed through unchanged.
    """
    result, new_global_state = self._query.get_noised_result(
        sample_state, global_state)

    # Record a sum query for each client who was selected, then finalise its sample.
    for i in range(self.M):
        if i in selected_indices:
            ledger = self.ledgers[i]
            ledger.record_sum_query(global_state.l2_norm_clip,
                                    global_state.noise_stddev)
            ledger.finalise_sample()

    def _detach_leaf(leaf):
        # Check each leaf rather than the whole result, so that nested results
        # (e.g. a list of per-group tensors) also get detached copies, while
        # non-tensor leaves, which cannot be cloned, pass through untouched.
        if isinstance(leaf, torch.Tensor):
            return leaf.clone().detach()
        return leaf

    return nest.map_structure(_detach_leaf, result), new_global_state
def preprocess_record(self, params, record):
    """Clip ``record`` to a global L2-norm bound and remember its pre-clipping norm.

    The global norm is the square root of the sum of squared L2 norms of every
    (flattened) tensor leaf in ``record``. That norm is stored on
    ``self._record_l2_norm`` so callers can deduce whether clipping occurred.

    :param params: the clipping bound ``l2_norm_clip``.
    :param record: nested structure of tensors to clip.
    :return: ``record`` itself when its global norm does not exceed the bound;
        otherwise a new structure with every leaf divided by
        ``abs(l2_norm / l2_norm_clip)``.
    """
    l2_norm_clip = params
    l2_norm = torch.sqrt(
        nest.reduce_structure(
            lambda p: torch.norm(torch.flatten(p), p=2) ** 2, torch.add, record
        )
    )
    self._record_l2_norm = l2_norm

    # `<=` rather than `<`: a record already at the bound needs no scaling. This
    # also avoids the 0/0 -> NaN division for an all-zero record with a zero bound.
    if l2_norm <= l2_norm_clip:
        return record

    # The scale factor is the same for every leaf, so compute it once.
    scale = torch.abs(l2_norm / l2_norm_clip)
    return nest.map_structure(lambda p: torch.div(p, scale), record)
def initial_sample_state(self, param_groups):
    """Build a zero-filled sample state mirroring ``param_groups``.

    :param param_groups: nested structure of tensors used as a shape template.
    :return: the same structure with each leaf replaced by ``torch.zeros_like(leaf)``.
    """
    zeros = nest.map_structure(torch.zeros_like, param_groups)
    return zeros
def merge_sample_states(self, sample_state_1, sample_state_2):
    """Combine two sample states by adding matching leaves with ``torch.add``.

    :param sample_state_1: first nested structure of tensors.
    :param sample_state_2: second structure, matching the first.
    :return: a new structure holding the leaf-wise sums.
    """
    merged = nest.map_structure(torch.add, sample_state_1, sample_state_2)
    return merged
def accumulate_preprocessed_record(self, sample_state, record):
    """Fold one preprocessed record into the running sample state.

    :param sample_state: current accumulated structure of tensors.
    :param record: preprocessed record with the same structure.
    :return: a new structure holding ``torch.add(state_leaf, record_leaf)`` per leaf.
    """
    accumulated = nest.map_structure(torch.add, sample_state, record)
    return accumulated
def initial_sample_state(self, param_groups):
    """Create an all-zero sample state shaped like the parameter groups.

    :param param_groups: nested structure of tensors whose shapes/dtypes are copied.
    :return: matching structure where every leaf is ``torch.zeros_like(leaf)``.
    """
    empty_state = nest.map_structure(torch.zeros_like, param_groups)
    return empty_state