def get_noised_result(self, sample_state, global_state):
    """Fetch the noised result from the wrapped query and finalise the ledger sample.

    Args:
        sample_state: Forwarded unchanged to ``self._query.get_noised_result``.
        global_state: Forwarded unchanged to ``self._query.get_noised_result``.

    Returns:
        A ``(result, new_global_state)`` pair. ``clone().detach()`` is applied
        to the result through ``nest.map_structure``; the new global state is
        returned exactly as the wrapped query produced it.
    """
    def _detached_copy(tensor):
        return tensor.clone().detach()

    noised, updated_global_state = self._query.get_noised_result(
        sample_state, global_state)
    # The sample is closed on the ledger only after the wrapped query returned.
    self._ledger.finalise_sample()
    return nest.map_structure(_detached_copy, noised), updated_global_state
    def get_noised_result(self, sample_state, global_state):
        """Perturb the sample state with noise scaled by ``global_state.noise_stddev``.

        When ``self._ledger`` is truthy, the sum query (clipping bound and
        noise standard deviation taken from ``global_state``) is recorded on it
        before any noise is drawn.

        Args:
            sample_state: Structure whose entries receive additive noise from
                ``torch.randn_like``.
            global_state: Object exposing ``l2_norm_clip`` and ``noise_stddev``.

        Returns:
            A ``(noised_sample_state, global_state)`` pair; ``global_state`` is
            returned unchanged.
        """
        ledger = self._ledger
        if ledger:
            ledger.record_sum_query(global_state.l2_norm_clip, global_state.noise_stddev)

        def _perturb(tensor):
            noise = torch.randn_like(tensor) * global_state.noise_stddev
            return tensor + noise

        noised_state = nest.map_structure(_perturb, sample_state)
        return noised_state, global_state
    def get_noised_result(self, sample_state, global_state, selected_indices):
        """Get the noised result and update the ledger of every selected client.

        Args:
            sample_state: Forwarded unchanged to ``self._query.get_noised_result``.
            global_state: Forwarded to the wrapped query; its ``l2_norm_clip``
                and ``noise_stddev`` are also what gets recorded on the ledgers.
            selected_indices: Container of client indices; every index in
                ``range(self.M)`` that it contains has its ledger updated.

        Returns:
            A ``(result, new_global_state)`` pair. A result that is a
            ``torch.Tensor`` is passed through ``clone().detach()``; any other
            result is returned as the wrapped query produced it.
        """
        noised, updated_global_state = self._query.get_noised_result(
            sample_state, global_state)

        # Only clients that were selected get a sum query recorded and finalised.
        for client_id in range(self.M):
            if client_id not in selected_indices:
                continue
            client_ledger = self.ledgers[client_id]
            client_ledger.record_sum_query(global_state.l2_norm_clip,
                                           global_state.noise_stddev)
            client_ledger.finalise_sample()

        if not isinstance(noised, torch.Tensor):
            return noised, updated_global_state

        def _detached_copy(tensor):
            return tensor.clone().detach()

        return nest.map_structure(_detached_copy, noised), updated_global_state
    def preprocess_record(self, params, record):
        """
        Clip the record so that its overall l2 norm does not exceed the bound.

        The l2 norm measured before clipping is stored on ``self._record_l2_norm``
        (it is not returned), so callers can deduce whether clipping occurred.

        :param params: the l2 norm clipping bound
        :param record: structure of tensors, traversed through ``nest``
        :return: ``record`` itself when its l2 norm is strictly below the bound,
            otherwise a structure with every tensor divided by ``|norm / bound|``
        """
        def _squared_l2(tensor):
            return torch.norm(torch.flatten(tensor), p=2) ** 2

        clip_bound = params
        # Global norm: square root of the summed squared norms of all tensors.
        total_norm = torch.sqrt(
            nest.reduce_structure(_squared_l2, torch.add, record))
        self._record_l2_norm = total_norm

        if total_norm < clip_bound:
            return record

        def _shrink(tensor):
            return torch.div(tensor, torch.abs(total_norm / clip_bound))

        return nest.map_structure(_shrink, record)
 def initial_sample_state(self, param_groups):
     """Build the starting sample state by mapping ``torch.zeros_like`` over ``param_groups``."""
     zero_state = nest.map_structure(torch.zeros_like, param_groups)
     return zero_state
 def merge_sample_states(self, sample_state_1, sample_state_2):
     """Combine two sample states by mapping ``torch.add`` over their paired entries."""
     merged_state = nest.map_structure(
         torch.add,
         sample_state_1,
         sample_state_2,
     )
     return merged_state
 def accumulate_preprocessed_record(self, sample_state, record):
     """Add an already-preprocessed record into the running sample state.

     ``torch.add`` is mapped over the paired entries of ``sample_state`` and
     ``record``; the resulting structure is returned.
     """
     updated_state = nest.map_structure(torch.add, sample_state, record)
     return updated_state
 def initial_sample_state(self, param_groups):
     """Create an all-zero sample state mirroring the parameter groups.

     ``torch.zeros_like`` is mapped over ``param_groups``, so every entry of
     the returned structure is a zero tensor shaped like its counterpart.
     """
     zeros = nest.map_structure(torch.zeros_like, param_groups)
     return zeros