forked from cilvento/b2_exponential_mechanism
-
Notifications
You must be signed in to change notification settings - Fork 0
/
expmech.py
306 lines (267 loc) · 11.4 KB
/
expmech.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
# The Base-2 Exponential Mechanism Reference Implementation
# author: C. Ilvento, cilvento@gmail.com
# status: reference implementation only
import gmpy2
from gmpy2 import mpfr, mpz
class ExpMech:
    """ The ExpMech base class.

    Implements sampling from a discrete outcome space with weights
    base^u(o), where base is computed from the integer parameters
    (eta_x, eta_y, eta_z) and u is an integer-valued (randomly rounded)
    utility. All arithmetic is carried out through gmpy2 with the inexact,
    overflow, underflow, etc. traps enabled on the active gmpy2 context.
    """

    # Check that the appropriate precision and context settings are in place
    def check_context(self):
        """ Check that the context settings including precision and inexact
        arithmetic flags are set properly.

        Returns:
            False if the precision has not been set yet or if the active
            gmpy2 context precision differs from the precision stored in
            self.context; otherwise the conjunction of the seven trap
            settings of the active context.
        """
        if not self.precision_set:
            return False
        ctx = gmpy2.get_context()
        if ctx.precision != self.context.precision:
            return False
        return (ctx.trap_inexact and ctx.trap_overflow and ctx.trap_erange
                and ctx.trap_divzero and ctx.trap_invalid
                and ctx.trap_expbound and ctx.trap_underflow)

    # Raise if the active gmpy2 context no longer matches the one stored at construction.
    def _require_context(self):
        """ Raise RuntimeError('Context invalid.') unless check_context() holds. """
        if not self.check_context():
            raise RuntimeError('Context invalid.')

    # A method to test whether the current precision is sufficient for intended usage.
    def check_precision(self):
        """ Performs test computations at the current precision intended to
        capture the workload of the exponential mechanism and catch if the
        current precision is sufficient.

        Side effect: stores the computed base in self.base.
        __init__ relies on these operations raising
        gmpy2.InexactResultError when the precision is too small; the
        intermediate results themselves are discarded.
        """
        # 1. Compute base = 2^(-eta) = eta_x^eta_z * 2^(-eta_y*eta_z)
        self.base = mpfr(pow(self.eta_x, self.eta_z))
        self.base *= mpfr(gmpy2.exp2(-gmpy2.mul(self.eta_y, self.eta_z)))
        # 2. Compute (base)^u_min and (base)^u_max
        min_weight = pow(self.base, self.u_min)
        max_weight = pow(self.base, self.u_max)
        mm = max_weight + min_weight  # computed only to exercise the addition
        # 3. Compute maximum total utility and minimum total utility
        max_total = max_weight*self.max_outcomes
        min_total = min_weight*self.max_outcomes
        # 4. Add max and min total utilities (results unused; evaluated for the traps)
        max_min_total = max_total + min_weight
        min_max_total = min_total + max_weight

    # Initialize a new mechanism
    def __init__(self, rng, eta_x=1, eta_y=0, eta_z=1,
                 u_min=10, u_max=0, max_O=100,
                 min_sampling_precision=10):
        """ Initializes a new ExpMech object including computing the required
        precision and setting inexact arithmetic exceptions.

        Note: this modifies the active (global) gmpy2 context: all traps are
        enabled and the precision is set to the value found by the search.

        **Args**:
          rng (function): a random bit generator;
          eta_x (int): privacy parameter;
          eta_y (int): privacy parameter;
          eta_z (int): privacy parameter;
          u_min (int): the minimum utility (maximum magnitude);
          u_max (int): the maximum utility (maximum magnitude);
          max_O (int): the maximum size of the outcome space;
          min_sampling_precision (int): the minimum precision at which to sample for randomized rounding

        **Raises**:
          RuntimeError: if a parameter cast reports an inexact result, if the
          maximum gmpy2 precision is reached without success, or if the
          resulting context fails check_context().
        """
        # initialize precision_set to False
        self.precision_set = False
        """ Status indicator for whether the precision necessary has been computed yet. """

        # Set the gmpy2 library context to trap on inexact arithmetic, overflows, underflows, etc.
        ctx = gmpy2.get_context()
        ctx.trap_inexact = True
        ctx.trap_overflow = True
        ctx.trap_divzero = True
        ctx.trap_invalid = True
        ctx.trap_underflow = True
        ctx.trap_expbound = True
        ctx.trap_erange = True

        # Set the rng, privacy parameters and utility bounds
        self.rng = rng
        """ Random bit generator used for generating all randomness in the procedures. """
        # NOTE(review): mpz() may silently truncate a float argument instead of
        # raising InexactResultError -- confirm this guard actually rejects
        # non-integer input for the gmpy2 version in use.
        try:  # try to cast the integer-valued arguments to mpz - raise an error if non-integer
            self.eta_x = mpz(eta_x)
            self.eta_y = mpz(eta_y)
            self.eta_z = mpz(eta_z)
            self.u_min = mpz(u_min)
            self.u_max = mpz(u_max)
            self.max_outcomes = mpz(max_O)
        except gmpy2.InexactResultError:
            raise RuntimeError('Non-integer parameter when integer was expected')

        # Compute the required precision for the desired parameters
        # start from a small-ish precision so we don't waste unnecessary bits
        ctx = gmpy2.get_context()
        ctx.precision = 16
        # Double the precision until the test workload runs without an inexact result.
        while not self.precision_set:
            if gmpy2.get_context().precision >= gmpy2.get_max_precision():
                raise RuntimeError('Failed to set precision: maximum precision exceeded.')
            try:
                self.check_precision()
            except gmpy2.InexactResultError:
                ctx = gmpy2.get_context()
                ctx.precision = 2*ctx.precision
            else:
                self.precision_set = True

        # The loop above only exits once precision_set is True (or by raising).
        ctx = gmpy2.get_context()
        ctx.clear_flags()  # clear any flags
        self.context = ctx.copy()  # store a copy to test future state
        """ The gmpy2 context required for computations on this mechanism. """
        # Explicit raise rather than assert so the check survives `python -O`.
        self._require_context()

        # Set the minimum sampling precision, which cannot be greater than the context precision
        self.min_sampling_precision = min(min_sampling_precision, self.context.precision)
        """ The minimum precision at which to perform sampling for randomized rounding. """

    # Sample a random value with p bits of precision
    # from a given starting power of 2. Output is in [0,2^{start_pow+1})
    def get_random_value(self, start_pow, p=None):
        """ Sample a random value based on self.rng in [0,2^(start_pow+1)).

        Args:
            start_pow: exponent of the most significant random bit.
            p: sampling precision; defaults to self.context.precision.

        The loop runs p-1 times, so p-1 values are drawn from self.rng and
        weighted by 2^start_pow down to 2^(start_pow-p+2).
        """
        if p is None:
            p = self.context.precision
        s = mpfr(0)
        nextbit = gmpy2.exp2(start_pow)
        for i in range(1, p):
            s = gmpy2.add(s, gmpy2.mul(nextbit, mpfr(self.rng())))
            nextbit = gmpy2.exp2(start_pow - i)
        return s

    # Randomized rounding logic
    def randomized_round(self, x):
        """ Round the input to an integer value. Value is rounded up with
        probability (x - int(x)), to the resolution of the sampled value.
        Rounding randomness is sampled at min_sampling_precision.

        The result is clamped to the range [u_max, u_min].
        """
        s = self.get_random_value(-1, self.min_sampling_precision)
        whole = int(x)
        cutoff = x - whole
        # Round up only when s < cutoff. Using `>=` (not `>`) for the round-down
        # test ensures an exactly integral x (cutoff == 0) is never rounded up
        # when the sampled value happens to be 0.
        output = whole if s >= cutoff else whole + 1
        return min(max(self.u_max, output), self.u_min)

    # Set the utility function, wrapped in randomized rounding logic
    # INPUTS: util, a utility function taking a single element as argument
    def set_utility(self, util):
        """ Set the utility function, wrapped in randomized rounding logic.

        Args:
          util (function): a function taking a single argument from the outcome space returning a single real value.
        """
        self.u = lambda x: self.randomized_round(util(x))

    # Sample an index from W according to the normalized weight of each entry using
    # randomness from rng
    def normalized_sample(self, W):
        """ Normalized sampling without division.

        Args:
          W: a non-empty sequence of weights from which to sample.
        Returns: an integer in [0,len(W)-1] corresponding to the index sampled.
        """
        t = gmpy2.fsum(W)  # compute total weight
        C = [gmpy2.fsum(W[0:i+1]) for i in range(0, len(W))]  # compute cumulative weights

        # Determine the maximum power of two for sampling:
        # after both loops, i_max is the smallest integer with 2^i_max > t.
        i_max = 0
        while gmpy2.exp2(i_max) > t:
            i_max -= 1
        while gmpy2.exp2(i_max) <= t:
            i_max += 1

        # sample a random number, rejecting values above the total weight
        s = gmpy2.exp2(i_max + 1)
        while s > t:
            s = self.get_random_value(i_max, self.context.precision)

        # return the element that matches the sampled index
        for i in range(0, len(W)):
            if C[i] >= s:
                return i

    # Optimized normalized sampling logic.
    def optimized_normalized_sample(self, W):
        """ Optimized Normalized sampling without division.

        Draws one random bit at a time and discards indices that can no
        longer be selected, stopping as soon as a single index remains.

        Args:
          W: a non-empty sequence of weights from which to sample.
        Returns: an integer in [0,len(W)-1] corresponding to the index sampled.
        WARNING: introduces a timing channel for differing weight distributions.
        """
        t = gmpy2.fsum(W)  # compute total weight
        C = [gmpy2.fsum(W[0:i+1]) for i in range(0, len(W))]  # compute cumulative weights

        s = 0
        log2t = 0
        while gmpy2.exp2(log2t) > t:
            log2t -= 1
        while gmpy2.exp2(log2t) <= t:
            log2t += 1
        j = log2t - 1

        # Candidate indices; built once so a restart does not grow C again.
        candidates = [i for i in range(0, len(W))]
        if t < gmpy2.exp2(log2t):
            candidates.append(len(W))  # add a dummy value
            C.append(gmpy2.exp2(log2t))
        remaining = list(candidates)

        while len(remaining) > 1:
            r = self.rng()
            s = s + r*gmpy2.exp2(j)
            to_remove = []
            for i in remaining:  # check if each remaining index is still reachable
                if C[i] <= s:
                    to_remove.append(i)
                if i > 0:
                    if C[i-1] >= s + gmpy2.exp2(j):
                        to_remove.append(i)
            for i in to_remove:
                remaining.remove(i)
            # Only the dummy slot is left: the draw fell above t, so start over.
            if len(remaining) == 1 and remaining[0] == len(W):
                s = 0
                j = log2t  # don't subtract 1, it's going to be decremented
                remaining = list(candidates)
            j -= 1
        return remaining[0]

    # Exact exponential mechanism
    def exact_exp_mech(self, O, optimized_sample=False):
        """ Run the mechanism over the outcome space O.

        Returns a single element from O sampled from the exponential mechanism.
        Defaults to un-optimized sampling logic.
        Setting optimized_sample=True can result in timing channels.

        Raises:
            RuntimeError: if O is empty, if O is larger than max_outcomes, or
            if the active gmpy2 context no longer passes check_context().
        """
        # check that O matches size requirements
        if len(O) > self.max_outcomes:
            raise RuntimeError('Outcome space size too large.')
        if len(O) == 0:
            # With no weights the power-of-two search in the samplers cannot terminate normally.
            raise RuntimeError('Outcome space is empty.')

        # Get utilities
        U = [self.u(o) for o in O]
        self._require_context()

        # Compute weights
        W = [pow(mpfr(self.base), mpfr(u)) for u in U]
        self._require_context()

        # Sample
        if optimized_sample:
            return O[self.optimized_normalized_sample(W)]
        return O[self.normalized_sample(W)]
class LaplaceMech(ExpMech):
    """ The LaplaceMech child class. Implements the clamped Laplace mechanism utility function over a discrete set of outcomes.
    Samples from the outcome space [b_min,b_max] at granularity gamma. """

    def __init__(self, rng, x, sensitivity, eta_x=1, eta_y=0, eta_z=1,
                 b_min=-10, b_max=10, gamma=2**(-4),
                 min_sampling_precision=10):
        """ Initializes the LaplaceMech including computing the required precision.

        Args:
          rng (function): a random bit generator;
          x (float or int): the target value of the mechanism
          sensitivity: the sensitivity associated with the computation of x
          eta_x (int): privacy parameter;
          eta_y (int): privacy parameter;
          eta_z (int): privacy parameter;
          b_min (float or int): the lower bound of the output range;
          b_max (float or int): the upper bound of the output range;
          gamma (float): the discretization granularity
          min_sampling_precision (int): the minimum precision at which to sample for randomized rounding

        Raises:
          RuntimeError: if sensitivity or gamma is not greater than 0, or if
          b_min exceeds b_max.
        """
        # Validate inputs before doing any work.
        if sensitivity <= 0:
            raise RuntimeError('Sensitivity must be greater than 0.')
        if gamma <= 0:
            # A non-positive step would make the outcome-building loop below never terminate.
            raise RuntimeError('Granularity gamma must be greater than 0.')
        if b_min > b_max:
            # Would otherwise produce an empty outcome space.
            raise RuntimeError('b_min must not be greater than b_max.')

        # compute outcome space: b_min, b_min+gamma, ... while not exceeding b_max
        O = []
        b = b_min
        while b <= b_max:
            O.append(b)
            b += gamma
        self.Outcomes = O
        """ The outcome space used by the mechanism [b_min,b_max] at granularity gamma. """

        # clamp x to range
        x = min(max(b_min, x), b_max)

        # specify utility function: distance from the clamped target, scaled by sensitivity
        u = lambda y: abs(x - y)/sensitivity
        max_u = 0
        # largest utility magnitude: the full width of the range, rounded up
        min_u = int((b_max - b_min)/sensitivity) + 1

        # initialize the base class
        max_O = len(self.Outcomes)+1
        ExpMech.__init__(self, rng, eta_x, eta_y, eta_z, min_u, max_u, max_O, min_sampling_precision)
        self.set_utility(u)

    def run_mechanism(self, optimized_sample=False):
        """ Runs the mechanism and returns a single outcome from O.
        Defaults to un-optimized sampling logic.
        Setting optimized_sample=True can result in timing channels."""
        O = self.Outcomes
        return self.exact_exp_mech(O, optimized_sample=optimized_sample)