def mask(x):
    """Mask a private tensor with fresh correlated randomness, memoised per tensor.

    The crypto producer samples a random value `a` (via `sample`) with the shape of
    `x` and secret-shares it as (a0, a1). Each server subtracts its share of `a` from
    its share of `x`; the two differences are then passed to `reconstruct` once on
    each server (the "exchange of alphas").

    The result is stored in `_nodes` under ``('mask', x)`` so that repeated calls
    with the same tensor reuse the same mask instead of sampling a new one.

    Args:
        x: the PrivateTensor to mask.

    Returns:
        MaskedPrivateTensor(x, a, a0, a1, alpha_on_0, alpha_on_1), where
        alpha_on_i is the output of `reconstruct` placed on server i.
    """
    assert isinstance(x, PrivateTensor)

    cache_key = ('mask', x)
    existing = _nodes.get(cache_key, None)
    if existing is not None:
        return existing

    share_0, share_1 = x.unwrapped
    x_shape = x.shape

    with tf.name_scope('mask'):
        with tf.device(get_active_protocol().crypto_producer.device_name):
            r = sample(x_shape)
            r_0, r_1 = share(r)

        with tf.device(get_active_protocol().server_0.device_name):
            diff_0 = crt_sub(share_0, r_0)

        with tf.device(get_active_protocol().server_1.device_name):
            diff_1 = crt_sub(share_1, r_1)

        # exchange of alphas: each server computes its own copy from both differences
        with tf.device(get_active_protocol().server_0.device_name):
            opened_on_0 = reconstruct(diff_0, diff_1)

        with tf.device(get_active_protocol().server_1.device_name):
            opened_on_1 = reconstruct(diff_0, diff_1)

        result = MaskedPrivateTensor(x, r, r_0, r_1, opened_on_0, opened_on_1)

    _nodes[cache_key] = result
    return result
def scale(x, k, apply_encoding=None):
    """Multiply a private tensor by a public scalar constant.

    Args:
        x: PrivateTensor to scale.
        k: public constant; must be exactly an `int` or a `float`.
        apply_encoding: whether to run `k` through `encode` before `decompose`.
            When None it is decided automatically: encoding is applied exactly
            when ``type(k) is float``. When encoding is applied, the product is
            also passed through `truncate`.

    Returns:
        A PrivateTensor whose shares are each share of `x` passed through
        `crt_scale` with the decomposed constant.
    """
    assert isinstance(x, PrivateTensor)
    assert type(k) in [int, float]

    share_0, share_1 = x.unwrapped

    needs_encoding = (type(k) is float) if apply_encoding is None else apply_encoding

    constant = np.array([k])
    constant = decompose(encode(constant) if needs_encoding else constant)

    with tf.name_scope('scale'):
        with tf.device(get_active_protocol().server_0.device_name):
            scaled_0 = crt_scale(share_0, constant)
        with tf.device(get_active_protocol().server_1.device_name):
            scaled_1 = crt_scale(share_1, constant)
        product = PrivateTensor(scaled_0, scaled_1)
        # presumably an encoded constant carries an extra scaling factor that
        # `truncate` removes -- confirm against `encode`
        return truncate(product) if needs_encoding else product
def concat(ys):
    """Concatenate masked private tensors along their first data axis.

    Every component of a MaskedPrivateTensor (both shares of the unmasked value, the
    mask `b`, its shares, and the reconstructed differences) is a list holding one
    tensor per decomposed (CRT) component. For each such component the inputs'
    tensors are concatenated along axis 0, and the results are reassembled into a
    single MaskedPrivateTensor. No fresh randomness is sampled; the inputs' masks
    are reused.

    Args:
        ys: non-empty iterable of MaskedPrivateTensors with matching component
            counts and compatible trailing shapes.

    Returns:
        A MaskedPrivateTensor holding the concatenation.

    Raises:
        ValueError: if `ys` is empty.
    """
    # FIXME[Morten] add support for PrivateTensors as well

    # `ys` is traversed twice below, so materialise it (this also accepts generators)
    ys = list(ys)
    if not ys:
        raise ValueError("concat requires at least one tensor")

    def helper(tensors):
        # `tensors` holds one list per input, each with one tensor per component,
        # e.g. [[(1000,2); 10]; 3] for 3 inputs with 10 components each.
        # The component count is derived rather than hard-coded so that any
        # number of components is supported.
        num_components = len(tensors[0])
        tensors = tf.concat(tensors, axis=1)
        # now shape is (10,3000,2)
        tensors = tf.split(tensors, num_components, axis=0)
        # now shape is [(1,3000,2); 10]
        tensors = [tf.reshape(tensor, tensor.shape[1:]) for tensor in tensors]
        # now shape is [(3000,2); 10]
        return tensors

    with tf.name_scope('concat'):
        y0s, y1s = zip(*[y.unmasked.unwrapped for y in ys])
        bs, b0s, b1s, beta_on_0s, beta_on_1s = zip(*[y.unwrapped for y in ys])

        with tf.device(get_active_protocol().crypto_producer.device_name):
            b = helper(bs)

        with tf.device(get_active_protocol().server_0.device_name):
            y0 = helper(y0s)
            b0 = helper(b0s)
            beta_on_0 = helper(beta_on_0s)

        with tf.device(get_active_protocol().server_1.device_name):
            y1 = helper(y1s)
            b1 = helper(b1s)
            beta_on_1 = helper(beta_on_1s)

        y = PrivateTensor(y0, y1)
        y_masked = MaskedPrivateTensor(y, b, b0, b1, beta_on_0, beta_on_1)

    return y_masked
def truncate(x):
    """Apply `raw_truncate` share-wise to a private tensor.

    Server 0 truncates its share directly. Server 1 computes
    ``M_wrapped - raw_truncate(M_wrapped - share1)``, i.e. it truncates the negation
    of its share and negates the result back. This looks like the usual two-party
    share-local truncation trick; presumably `M_wrapped` is the modulus in the same
    decomposed representation -- confirm at its definition.

    Args:
        x: PrivateTensor to truncate.

    Returns:
        A new PrivateTensor with the truncated shares.
    """
    assert isinstance(x, PrivateTensor)

    share_0, share_1 = x.share0, x.share1

    with tf.name_scope('truncate'):
        with tf.device(get_active_protocol().server_0.device_name):
            truncated_0 = raw_truncate(share_0)

        with tf.device(get_active_protocol().server_1.device_name):
            negated_1 = crt_sub(M_wrapped, share_1)
            truncated_1 = crt_sub(M_wrapped, raw_truncate(negated_1))

    return PrivateTensor(truncated_0, truncated_1)
def assign(x, v):
    """Build `tf.assign` ops copying each share component of `v` into `x`.

    Assignment happens component-wise and share-wise on the device of the server
    that holds the share, so no data crosses between servers. The components of
    `x` must therefore be assignable (e.g. `tf.Variable`s).

    Args:
        x: PrivateTensor whose components are the assignment targets.
        v: PrivateTensor holding the new values.

    Returns:
        A pair ``(ops_0, ops_1)`` of lists of assign ops for server 0 and server 1.
    """
    assert isinstance(x, PrivateTensor)
    assert isinstance(v, PrivateTensor)

    targets_0, targets_1 = x.share0, x.share1
    values_0, values_1 = v.share0, v.share1

    with tf.name_scope("assign"):
        with tf.device(get_active_protocol().server_0.device_name):
            ops_0 = [tf.assign(target, value) for target, value in zip(targets_0, values_0)]

        with tf.device(get_active_protocol().server_1.device_name):
            ops_1 = [tf.assign(target, value) for target, value in zip(targets_1, values_1)]

    return ops_0, ops_1
def sigmoid(x):
    """Evaluate the fixed odd polynomial used here to approximate sigmoid.

    Computes ``w0 + w1*x + w3*x^3 + w5*x^5 + w7*x^7 + w9*x^9`` on a private tensor.
    The odd powers are built by repeated multiplication by ``x^2``, each term is
    produced with `scale`, and the public constant `w0` is added to server 0's
    share only. The input range over which this polynomial is accurate is not
    established here -- TODO confirm.

    Args:
        x: PrivateTensor input.

    Returns:
        A PrivateTensor holding the polynomial's value.
    """
    assert isinstance(x, PrivateTensor)

    w0 = 0.5
    w1 = 0.2159198015
    w3 = -0.0082176259
    w5 = 0.0001825597
    w7 = -0.0000018848
    w9 = 0.0000000072

    with tf.name_scope('sigmoid'):
        # TODO optimise depth
        x_squared = square(x)

        # x, x^3, x^5, x^7, x^9: each step multiplies the previous power by x^2
        odd_powers = [x]
        for _ in range(4):
            odd_powers.append(mul(x_squared, odd_powers[-1]))

        terms = [scale(power, weight)
                 for power, weight in zip(odd_powers, [w1, w3, w5, w7, w9])]

        # Fold right-to-left so the nesting is crt_add(t1, crt_add(t3, ... )).
        with tf.device(get_active_protocol().server_0.device_name):
            total_0 = decompose(encode(np.array([w0])))
            for term in reversed(terms):
                total_0 = crt_add(term.share0, total_0)

        with tf.device(get_active_protocol().server_1.device_name):
            total_1 = terms[-1].share1
            for term in reversed(terms[:-1]):
                total_1 = crt_add(term.share1, total_1)

        return PrivateTensor(total_0, total_1)
def dot(x, y):
    """Multiply two private tensors with `crt_dot` using their masks; memoised.

    Plain PrivateTensor arguments are first masked via `mask`. The crypto producer
    computes ``ab = crt_dot(a, b)`` from the two masks and secret-shares it. Each
    server then combines its share of `ab` with its mask shares and its copy of the
    reconstructed differences (alpha for `x`, beta for `y`). The result is passed
    through `truncate` and stored in `_nodes`.

    Assuming `reconstruct` yields alpha = x - a and beta = y - b (as set up in
    `mask`), the shares sum to
        ab + a.beta + alpha.b + alpha.beta = (a + alpha).(b + beta) = x.y,
    with the public alpha.beta term added on server 0 only.

    Args:
        x, y: PrivateTensor or MaskedPrivateTensor operands.

    Returns:
        A (truncated) PrivateTensor holding the product.
    """
    # NOTE(review): the memo key uses the arguments as passed, so a PrivateTensor
    # and its MaskedPrivateTensor are cached under different keys.
    node_key = ('dot', x, y)
    z = _nodes.get(node_key, None)
    if z is None:
        if isinstance(x, PrivateTensor):
            x = mask(x)
        if isinstance(y, PrivateTensor):
            y = mask(y)
        assert isinstance(x, MaskedPrivateTensor)
        assert isinstance(y, MaskedPrivateTensor)
        a, a0, a1, alpha_on_0, alpha_on_1 = x.unwrapped
        b, b0, b1, beta_on_0, beta_on_1 = y.unwrapped
        with tf.name_scope('dot'):
            # product of the masks, generated and shared by the crypto producer
            with tf.device(get_active_protocol().crypto_producer.device_name):
                ab = crt_dot(a, b)
                ab0, ab1 = share(ab)
            with tf.device(get_active_protocol().server_0.device_name):
                alpha = alpha_on_0
                beta = beta_on_0
                # server 0 also adds the public alpha.beta term
                z0 = crt_add(
                    ab0,
                    crt_add(
                        crt_dot(a0, beta),
                        crt_add(crt_dot(alpha, b0), crt_dot(alpha, beta))))
            with tf.device(get_active_protocol().server_1.device_name):
                alpha = alpha_on_1
                beta = beta_on_1
                z1 = crt_add(ab1, crt_add(crt_dot(a1, beta), crt_dot(alpha, b1)))
            z = PrivateTensor(z0, z1)
            z = truncate(z)
            _nodes[node_key] = z
    return z
def split(y, num_splits):
    """Split a masked private tensor into `num_splits` pieces along its first data axis.

    Every component of `y` (both unmasked shares, the mask `b`, its shares, and the
    reconstructed differences) is a list holding one tensor per decomposed (CRT)
    component. Each list is stacked, split along the data axis, and unstacked back
    into per-component lists. The pieces reuse slices of `y`'s mask; no fresh
    randomness is sampled.

    Args:
        y: MaskedPrivateTensor to split.
        num_splits: number of equal pieces; passed straight to `tf.split`, so the
            first data dimension must be divisible by it.

    Returns:
        A list of `num_splits` MaskedPrivateTensors.
    """
    assert isinstance(y, MaskedPrivateTensor)
    # FIXME[Morten] add support for PrivateTensors as well

    y0, y1 = y.unmasked.unwrapped
    b, b0, b1, beta_on_0, beta_on_1 = y.unwrapped

    def helper(tensors):
        # FIXME[Morten] all this reshaping seems to incur a big hit on (at least) graph building
        # Derive the component count instead of hard-coding it, so any number of
        # components is supported.
        num_components = len(tensors)
        # as an example, assume shape is [(3000,2); 10]
        tensors = tf.stack(tensors)
        # now shape is (10,3000,2)
        tensors = tf.split(tensors, num_splits, axis=1)
        # now shape is [(10,30,2); 100] if num_splits == 100
        tensors = [
            [tf.reshape(xi, xi.shape[1:])
             for xi in tf.split(tensor, num_components, axis=0)]
            for tensor in tensors
        ]
        # now shape is [[(30,2); 10]; 100]
        return tensors

    with tf.name_scope('split'):
        with tf.device(get_active_protocol().crypto_producer.device_name):
            bs = helper(b)

        with tf.device(get_active_protocol().server_0.device_name):
            y0s = helper(y0)
            b0s = helper(b0)
            beta_on_0s = helper(beta_on_0)

        with tf.device(get_active_protocol().server_1.device_name):
            y1s = helper(y1)
            b1s = helper(b1)
            beta_on_1s = helper(beta_on_1)

    # fresh loop names so the parameter `y` and the locals above are not shadowed
    return [
        MaskedPrivateTensor(PrivateTensor(piece_y0, piece_y1), piece_b,
                            piece_b0, piece_b1, piece_beta_on_0, piece_beta_on_1)
        for piece_y0, piece_y1, piece_b, piece_b0, piece_b1, piece_beta_on_0,
        piece_beta_on_1 in zip(y0s, y1s, bs, b0s, b1s, beta_on_0s, beta_on_1s)
    ]
def square(x):
    """Square a private tensor with `crt_mul` using its mask; memoised in `_nodes`.

    A plain PrivateTensor is first masked via `mask`. The crypto producer computes
    ``aa = crt_mul(a, a)`` from the mask and secret-shares it. Each server combines
    its share of `aa` with its mask share and its copy of the reconstructed
    difference alpha. The result is passed through `truncate`.

    Assuming `reconstruct` yields alpha = x - a (as set up in `mask`), the shares
    sum to aa + a*alpha + alpha*a + alpha*alpha = (a + alpha)^2 = x^2, with the
    public alpha*alpha term added on server 0 only. Whether `crt_mul` is
    element-wise is not shown here -- presumably yes, confirm.

    Args:
        x: PrivateTensor or MaskedPrivateTensor operand.

    Returns:
        A (truncated) PrivateTensor holding the square.
    """
    node_key = ('square', x)
    y = _nodes.get(node_key, None)
    if y is None:
        if isinstance(x, PrivateTensor):
            x = mask(x)
        assert isinstance(x, MaskedPrivateTensor)
        a, a0, a1, alpha_on_0, alpha_on_1 = x.unwrapped
        with tf.name_scope('square'):
            # square of the mask, generated and shared by the crypto producer
            with tf.device(get_active_protocol().crypto_producer.device_name):
                aa = crt_mul(a, a)
                aa0, aa1 = share(aa)
            with tf.device(get_active_protocol().server_0.device_name):
                alpha = alpha_on_0
                # server 0 also adds the public alpha*alpha term
                y0 = crt_add(
                    aa0,
                    crt_add(
                        crt_mul(a0, alpha),
                        crt_add(
                            crt_mul(alpha, a0),  # TODO replace with `scale(, 2)` op
                            crt_mul(alpha, alpha))))
            with tf.device(get_active_protocol().server_1.device_name):
                alpha = alpha_on_1
                y1 = crt_add(aa1, crt_add(crt_mul(a1, alpha), crt_mul(
                    alpha, a1)))  # TODO replace with `scale(, 2)` op
            y = PrivateTensor(y0, y1)
            y = truncate(y)
            _nodes[node_key] = y
    return y
def define_variable(initial_value, apply_encoding=True, name=None):
    """Create secret-shared TensorFlow variables initialised from a public value.

    `initial_value` is optionally run through `encode`, then through `decompose`,
    and split into two shares with `share`. Each server creates one `tf.Variable`
    (dtype INT_TYPE) per component of its own share, on its own device.

    Args:
        initial_value: public starting value.
        apply_encoding: whether to `encode` `initial_value` first (default True).
        name: optional suffix for the name scope, giving ``var-<name>``; when
            falsy the scope is just ``var``.

    Returns:
        A pair ``(x, initializers)``: a PrivateTensor built from the variables'
        `read_value()` tensors, and the initializer ops of server 0's variables
        followed by those of server 1's.
    """
    plain = encode(initial_value) if apply_encoding else initial_value
    shares_0, shares_1 = share(decompose(plain))

    def materialise(shares):
        # one variable per component; return (initializer ops, read tensors)
        variables = [tf.Variable(component, dtype=INT_TYPE) for component in shares]
        inits = [var.initializer for var in variables]
        reads = [var.read_value() for var in variables]
        return inits, reads

    with tf.name_scope('var{}'.format('-' + name if name else '')):
        with tf.device(get_active_protocol().server_0.device_name):
            inits_0, reads_0 = materialise(shares_0)

        with tf.device(get_active_protocol().server_1.device_name):
            inits_1, reads_1 = materialise(shares_1)

    return PrivateTensor(reads_0, reads_1), inits_0 + inits_1
def transpose(x):
    """Transpose every component of both shares of a private tensor.

    If a mask for `x` is already memoised in `_nodes` under ``('mask', x)``, that
    mask is transposed component-wise too (on the devices that hold each part) and
    memoised under ``('mask', x_t)``. A later `mask(x_t)` then reuses it instead of
    sampling fresh randomness.

    Args:
        x: PrivateTensor to transpose.

    Returns:
        The transposed PrivateTensor `x_t`.
    """
    assert isinstance(x, PrivateTensor)

    x0, x1 = x.share0, x.share1

    with tf.name_scope('transpose'):
        with tf.device(get_active_protocol().server_0.device_name):
            x0_t = [tf.transpose(t) for t in x0]

        with tf.device(get_active_protocol().server_1.device_name):
            x1_t = [tf.transpose(t) for t in x1]

        x_t = PrivateTensor(x0_t, x1_t)

        x_masked = _nodes.get(('mask', x), None)
        # explicit None check: the memo stores None for "absent", and the
        # truthiness of a MaskedPrivateTensor object is not what is meant here
        if x_masked is not None:
            # use mask for `x` to get mask for `x_t`
            a, a0, a1, alpha_on_0, alpha_on_1 = x_masked.unwrapped

            with tf.device(get_active_protocol().crypto_producer.device_name):
                a_t = [tf.transpose(t) for t in a]

            with tf.device(get_active_protocol().server_0.device_name):
                a0_t = [tf.transpose(t) for t in a0]
                alpha_on_0_t = [tf.transpose(t) for t in alpha_on_0]

            with tf.device(get_active_protocol().server_1.device_name):
                a1_t = [tf.transpose(t) for t in a1]
                alpha_on_1_t = [tf.transpose(t) for t in alpha_on_1]

            x_masked_t = MaskedPrivateTensor(x_t, a_t, a0_t, a1_t, alpha_on_0_t, alpha_on_1_t)
            _nodes[('mask', x_t)] = x_masked_t

    return x_t
def sub(x, y):
    """Subtract private tensor `y` from `x` share-wise, memoised in `_nodes`.

    Each server subtracts its share of `y` from its share of `x` with `crt_sub`;
    no communication is needed. The result is cached under ``('sub', x, y)``.

    Args:
        x, y: PrivateTensor operands.

    Returns:
        A PrivateTensor holding ``x - y``.
    """
    assert isinstance(x, PrivateTensor)
    assert isinstance(y, PrivateTensor)

    key = ('sub', x, y)
    difference = _nodes.get(key, None)
    if difference is not None:
        return difference

    lhs_0, lhs_1 = x.unwrapped
    rhs_0, rhs_1 = y.unwrapped

    with tf.name_scope("sub"):
        with tf.device(get_active_protocol().server_0.device_name):
            diff_0 = crt_sub(lhs_0, rhs_0)
        with tf.device(get_active_protocol().server_1.device_name):
            diff_1 = crt_sub(lhs_1, rhs_1)
        difference = PrivateTensor(diff_0, diff_1)

    _nodes[key] = difference
    return difference
def _cache_crt_variables(tensors, updators):
    """Create one random-initialised variable per component of `tensors` and queue an update.

    Variable i is created with dtype INT_TYPE and initialised by
    ``tf.random_uniform(shape=tensors[i].shape, maxval=m[i], dtype=INT_TYPE)``.
    `m` is the module-level list, presumably the CRT moduli -- confirm at its
    definition. A single list of `tf.assign` ops, copying each tensor into its
    variable, is appended to `updators`. Call this inside the `tf.device` scope of
    the party that holds `tensors`.

    Returns:
        The list of created variables.
    """
    variables = [
        tf.Variable(tf.random_uniform(shape=vi.shape, maxval=mi, dtype=INT_TYPE),
                    dtype=INT_TYPE)
        for vi, mi in zip(tensors, m)
    ]
    updators.append([tf.assign(var, val) for var, val in zip(variables, tensors)])
    return variables


def cache(x, initializers=None, updators=None):
    """Mirror a private or masked private tensor into variables, memoised in `_nodes`.

    For every component of every share (and, for a masked tensor, of the mask, its
    shares, and the reconstructed differences) a `tf.Variable` is created on the
    device of the party holding that component. A list of `tf.assign` ops copying
    the current values into the variables is appended to `updators` per group;
    running those ops refreshes the cache. The result is stored under
    ``('cache', x)``.

    Args:
        x: PrivateTensor or MaskedPrivateTensor to cache.
        initializers: currently unused (see TODO below).
        updators: list receiving the assign-op lists; defaults to the module-level
            `global_cache_updators`.

    Returns:
        A PrivateTensor or MaskedPrivateTensor wrapping the created variables.

    Raises:
        AssertionError: if `x` is of neither supported type.
    """
    if updators is None:
        updators = global_cache_updators

    # TODO[Morten] use `initializers`

    node_key = ('cache', x)
    cached = _nodes.get(node_key, None)
    if cached is not None:
        return cached

    if isinstance(x, PrivateTensor):
        x0, x1 = x.unwrapped

        with tf.name_scope('cache'):
            with tf.device(get_active_protocol().server_0.device_name):
                cached_x0 = _cache_crt_variables(x0, updators)

            with tf.device(get_active_protocol().server_1.device_name):
                cached_x1 = _cache_crt_variables(x1, updators)

            # TODO[Morten] wrap PrivateTensor around var.read_value() instead to ensure updated values?
            cached = PrivateTensor(cached_x0, cached_x1)
            _nodes[node_key] = cached

    elif isinstance(x, MaskedPrivateTensor):
        a, a0, a1, alpha_on_0, alpha_on_1 = x.unwrapped
        # the unmasked part is cached (and memoised) on its own
        cached_x = cache(x.unmasked, initializers, updators)

        with tf.name_scope('cache'):
            with tf.device(get_active_protocol().crypto_producer.device_name):
                cached_a = _cache_crt_variables(a, updators)

            with tf.device(get_active_protocol().server_0.device_name):
                cached_a0 = _cache_crt_variables(a0, updators)
                cached_alpha_on_0 = _cache_crt_variables(alpha_on_0, updators)

            with tf.device(get_active_protocol().server_1.device_name):
                cached_a1 = _cache_crt_variables(a1, updators)
                cached_alpha_on_1 = _cache_crt_variables(alpha_on_1, updators)

            # TODO[Morten] wrap MaskedPrivateTensor around var.read_value() instead to ensure updated values?
            cached = MaskedPrivateTensor(cached_x, cached_a, cached_a0, cached_a1,
                                         cached_alpha_on_0, cached_alpha_on_1)
            _nodes[node_key] = cached

    else:
        raise AssertionError("'x' not of supported type")

    return cached