def test_server():
    init_communicate(Config.server_rank)
    prot = SharesMultServer(num_elem, modulus, fhe_builder, "test_shares_mult")
    with NamedTimerInstance("Server Offline"):
        prot.offline()
    torch_sync()
    with NamedTimerInstance("Server Online"):
        prot.online(a_s, b_s)
    torch_sync()

    blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_u_c")
    blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_v_c")
    blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_z_c")
    blob_u_c.prepare_recv()
    blob_v_c.prepare_recv()
    blob_z_c.prepare_recv()
    torch_sync()
    u_c = blob_u_c.get_recv()
    v_c = blob_v_c.get_recv()
    z_c = blob_z_c.get_recv()
    u = pmod(prot.u_s + u_c, modulus)
    v = pmod(prot.v_s + v_c, modulus)
    check_correctness_online(u, v, prot.z_s, z_c)

    blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
    blob_c_c.prepare_recv()
    torch_sync()
    c_c = blob_c_c.get_recv()
    check_correctness_online(a, b, prot.c_s, c_c)
    end_communicate()
def test_torch_ntt():
    modulus = 786433
    img_hw = 64
    filter_hw = 3
    padding = 1
    data_bit = 17

    len_vector = img_hw
    data_range = 2 ** data_bit
    root, mod, ntt_mat, inv_mat = generate_ntt_param(modulus, len_vector, data_range)

    # x = np.arange(img_hw ** 2).reshape([img_hw, img_hw]).astype(np.int64)
    # w = np.arange(filter_hw ** 2).reshape([filter_hw, filter_hw]).astype(np.int64)
    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, img_hw ** 2).reshape([img_hw, img_hw]).double()
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, filter_hw ** 2).reshape([filter_hw, filter_hw]).double()
    ntt_mat = ntt_mat.double()
    inv_mat = inv_mat.double()

    with NamedTimerInstance("Mat NTT 2d"):
        ntted = ntt_mat2d(ntt_mat, mod, x)
    reved = ntt_mat2d(inv_mat, mod, ntted)
    expected = pmod(x, modulus).type(torch.int)
    actual = pmod(reved, modulus).type(torch.int)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
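# These functions lean on pmod/nmod for modular reduction, but neither helper
# is defined in this excerpt. A minimal sketch of the semantics the call sites
# appear to assume (pmod_sketch/nmod_sketch are hypothetical names, not the
# project's actual implementations):
import torch

def pmod_sketch(x, p):
    # Non-negative representative of x mod p, in [0, p).
    return torch.remainder(x, p)

def nmod_sketch(x, p):
    # Centered representative of x mod p, in roughly (-p/2, p/2].
    r = torch.remainder(x, p)
    return torch.where(r > p // 2, r - p, r)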
def conv2d(self, x, w):
    self.load_and_ntt_w(w)
    self.load_and_ntt_x(x)

    # Self-check: the inverse NTT of the first input channel should
    # reproduce x[0] once cropped back to img_hw x img_hw.
    sub_x = x[0]
    sub_ntted_x = self.ntted_x[0]
    inv_sub_ntted_x = self.ntt_matmul.intt2d(sub_ntted_x.double())
    trun_inv_sub_ntted_x = inv_sub_ntted_x[:self.img_hw, :self.img_hw]
    compare_expected_actual(pmod(sub_x, self.modulus), pmod(trun_inv_sub_ntted_x, self.modulus),
                            get_relative=True, name="sub_x")

    # Self-check: the filter is stored rotated by 180 degrees, so the inverse
    # NTT of the first filter should match w[0, 0].rot90(2).
    sub_w = w[0, 0]
    sub_ntted_w = self.ntted_w[0, 0]
    inv_sub_ntted_w = self.ntt_matmul.intt2d(sub_ntted_w.double())
    trun_inv_sub_ntted_w = inv_sub_ntted_w[:self.filter_hw, :self.filter_hw]
    expected = pmod(sub_w, self.modulus).rot90(2)
    actual = pmod(trun_inv_sub_ntted_w, self.modulus)
    compare_expected_actual(expected, actual, get_relative=True, name="sub_w")

    # Single-channel probe of the NTT-domain pointwise product.
    dotted = self.conv2d_ntted_single_channel(sub_ntted_x, sub_ntted_w)
    sub_y = self.transform_y_single_channel(dotted)

    return self.conv2d_loaded()
def online(self, input_s):
    input_s = input_s.to(self.comp_device)
    # Hide this party's share behind an additive mask before sending.
    masked_input_s = pmod(input_s + self.input_mask_s, self.modulus)
    self.blob_masked_input_s.send(masked_input_s)
    masked_output_s = self.blob_masked_output_s.get_recv()
    # Strip the same mask from the reply to recover the output share.
    self.output_s = pmod(masked_output_s - self.input_mask_s, self.modulus)
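# A toy illustration of the additive-masking invariant the online phase above
# relies on: adding a uniform mask mod q perfectly hides a value, and
# subtracting the same mask recovers it exactly (values hypothetical):
q = 786433
x, mask = 123456, 654321
masked = (x + mask) % q          # what goes over the wire
recovered = (masked - mask) % q  # what the unmasking step computes
assert recovered == x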
def sum_c_i_offline(self, delta_a, fhe_beta_i_c, fhe_alpha_beta_xor_c, s,
                    alpha_i, ci_mask_s, mult_mask_s, shuffle_order):
    # The last row of sum_xor is c_{-1}, which helps check the case with x == y.
    fhe_builder = self.fhe_builder_16

    # Suffix sums of the XOR bits: fhe_sum_xor[i] = sum_{j > i} (alpha_j XOR beta_j).
    # fhe_sum_xor = [fhe_builder.build_enc(self.num_elem) for i in range(self.num_work_batch)]
    fhe_sum_xor = [None for i in range(self.num_work_batch)]
    fhe_sum_xor[self.work_bit - 1] = fhe_builder.build_enc(self.num_elem)
    for i in range(self.work_bit - 1)[::-1]:
        fhe_sum_xor[i] = fhe_sum_xor[i + 1].copy()
        fhe_sum_xor[i] += fhe_alpha_beta_xor_c[i + 1]

    # Extra row c_{-1}: delta_a plus the XOR sum over all bits.
    fhe_delta_a = fhe_builder.build_plain_from_torch(delta_a)
    fhe_sum_xor[self.work_bit] = fhe_sum_xor[0].copy()
    fhe_sum_xor[self.work_bit] += fhe_alpha_beta_xor_c[0]
    fhe_sum_xor[self.work_bit] += fhe_delta_a
    del fhe_delta_a

    for i in range(self.work_bit)[::-1]:
        fhe_mult_3 = fhe_builder.build_plain_from_torch(pmod(3 * mult_mask_s[i].cpu(), self.q_16))
        fhe_mult_mask_s = fhe_builder.build_plain_from_torch(mult_mask_s[i])
        masked_s = pmod(s.type(torch.int64) * mult_mask_s[i].type(torch.int64), self.q_16).type(torch.float32)
        # print("s * mult_mask_s[i]", torch.max(masked_s))
        fhe_s = fhe_builder.build_plain_from_torch(masked_s)
        fhe_alpha_i = fhe_builder.build_plain_from_torch(alpha_i[i] * mult_mask_s[i])
        fhe_ci_mask_s = fhe_builder.build_plain_from_torch(ci_mask_s[i])
        # Accumulate mult_mask * (3 * sum_xor - beta_i + s + alpha_i) + ci_mask.
        fhe_beta_i_c[i] *= fhe_mult_mask_s
        fhe_sum_xor[i] *= fhe_mult_3
        fhe_sum_xor[i] -= fhe_beta_i_c[i]
        fhe_sum_xor[i] += fhe_s
        fhe_sum_xor[i] += fhe_alpha_i
        fhe_sum_xor[i] += fhe_ci_mask_s
        del fhe_mult_3, fhe_mult_mask_s, fhe_s, fhe_alpha_i, fhe_ci_mask_s

    fhe_mult_mask_s = fhe_builder.build_plain_from_torch(mult_mask_s[self.work_bit])
    fhe_ci_mask_s = fhe_builder.build_plain_from_torch(ci_mask_s[self.work_bit])
    fhe_sum_xor[self.work_bit] *= fhe_mult_mask_s
    fhe_sum_xor[self.work_bit] += fhe_ci_mask_s
    del fhe_mult_mask_s, fhe_ci_mask_s

    if self.is_shuffle:
        with NamedTimerInstance("Shuffle"):
            refresher = EncRefresherServer(self.sum_shape, fhe_builder, self.sub_name("shuffle_refresher"))
            with NamedTimerInstance("refresh"):
                new_fhe_sum_xor = refresher.request(fhe_sum_xor)
            del fhe_sum_xor
            fhe_sum_xor = self.generate_fhe_shuffled(shuffle_order, new_fhe_sum_xor)
            del refresher

    return fhe_sum_xor
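# sum_c_i_offline evaluates, under FHE plus multiplicative/additive masks, the
# standard DGK comparison terms c_i = s + alpha_i - beta_i +
# 3 * sum_{j>i} (alpha_j XOR beta_j) mod q, where the comparison holds iff some
# c_i is zero. A plaintext sketch of that arithmetic (hypothetical helper; bit
# index 0 is the LSB, matching the order decomp_to_bit below produces):
import torch

def dgk_c_terms_sketch(alpha_bits, beta_bits, s, q):
    # alpha_bits, beta_bits: [work_bit, n] float tensors of 0/1 bits.
    xor = (alpha_bits + beta_bits) % 2
    work_bit = alpha_bits.shape[0]
    c = torch.zeros_like(alpha_bits)
    running = torch.zeros(alpha_bits.shape[1])  # sum of xor over j > i
    for i in range(work_bit - 1, -1, -1):
        c[i] = (s + alpha_bits[i] - beta_bits[i] + 3 * running) % q
        running = running + xor[i]
    return c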
def masking_output(self):
    spread_mask = generate_random_mask(self.modulus, [self.num_output_batch, self.degree])
    self.output_mask_s = torch.zeros(self.num_output_unit).double()

    pod_vector = uIntVector()
    pt = Plaintext()
    for idx_output_batch in range(self.num_output_batch):
        encoding_tensor = torch.zeros(self.degree, dtype=torch.float)
        for idx_piece in range(self.num_piece_in_batch):
            idx_output_unit = self.index_output_batch_to_units(idx_output_batch, idx_piece)
            if idx_output_unit is False:
                break
            padded_span = self.num_elem_in_piece
            start_piece = idx_piece * padded_span
            arr = spread_mask[idx_output_batch, start_piece:start_piece + padded_span]
            encoding_tensor[start_piece:start_piece + padded_span] = arr
            # The decoder sums each piece's slots, so the effective mask on
            # this output unit is the sum of its slot masks.
            self.output_mask_s[idx_output_unit] = arr.double().sum()
        encoding_tensor = pmod(encoding_tensor, self.modulus)
        pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
        self.batch_encoder.encode(pod_vector, pt)
        self.evaluator.add_plain_inplace(self.output_cts[idx_output_batch], pt)
    self.output_mask_s = pmod(self.output_mask_s, self.modulus)
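# The unmasking identity used above, in miniature: if each output unit is
# decoded as the sum of its piece's slots, then masking every slot under
# encryption and remembering the per-unit mask sum lets the mask holder
# unmask the decoded value directly (hypothetical numbers, mod q):
q = 786433
slots = [11, 22, 33]  # one output unit's piece after the dot product
mask = [5, 6, 7]      # per-slot masks added to the ciphertext
decoded = sum((v + m) % q for v, m in zip(slots, mask)) % q
assert (decoded - sum(mask)) % q == sum(slots) % q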
def check_correctness_online(output, input_s, input_c):
    expected = pmod(output.cuda(), modulus)
    actual = pmod(input_s.cuda() + input_c.cuda(), modulus)
    compare_expected_actual(expected, actual, name=test_name + " online", get_relative=True)
def correctness_fc(self, input_img, output, modulus):
    x = input_img.cuda().double()
    x = pmod(torch.mm(x.view(1, -1), self.layers[1].weight.cuda().double().t()).view(-1), modulus)
    expected = x
    actual = pmod(output, modulus)
    if len(expected.shape) == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="fc", get_relative=True)
def check_correctness_online(a, b, c_s, c_c):
    expected = pmod(a.double().to(Config.device) * b.double().to(Config.device), modulus)
    actual = pmod(c_s + c_c, modulus)
    compare_expected_actual(expected, actual, name="shares_mult_online", get_relative=True)
def check_correctness_offline(u, v, z_s, z_c):
    expected = pmod(u.double().to(Config.device) * v.double().to(Config.device), modulus)
    actual = pmod(z_s + z_c, modulus)
    compare_expected_actual(expected, actual, name="shares_mult_offline", get_relative=True)
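# The two checks above validate a Beaver multiplication triple: offline the
# parties hold additive shares of a random triple (u, v, z = u*v mod q);
# online they open e = a - u and f = b - v and derive shares of c = a*b.
# A single-party plaintext sketch of the reconstruction identity
# (hypothetical helper, plain integers):
def beaver_mult_sketch(a, b, u, v, z, q):
    e = (a - u) % q  # opened; uniformly distributed, so it leaks nothing about a
    f = (b - v) % q  # opened; uniformly distributed, so it leaks nothing about b
    # z + e*v + f*u + e*f == u*v + (a-u)*v + (b-v)*u + (a-u)*(b-v) == a*b (mod q)
    return (z + e * v + f * u + e * f) % q

assert beaver_mult_sketch(7, 9, 3, 5, 15, 97) == (7 * 9) % 97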
def correctness_relu_only_nn(self, input_img, output, modulus):
    x = input_img.cuda().double()
    x = x.reshape([1] + list(x.shape))
    x = pmod(F.relu(nmod(x, modulus)), modulus)
    expected = x
    actual = pmod(output, modulus)
    if len(expected.shape) == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="relu_only_nn", get_relative=True)
def check_correctness_online(img, max_s, max_c):
    img = torch.where(img < q_23 // 2, img, img - q_23).to(Config.device)
    pool = torch.nn.MaxPool2d(2)
    expected = pool(img.reshape([-1, img_hw, img_hw])).reshape(-1)
    expected = pmod(expected, q_23)
    actual = pmod(max_s + max_c, q_23)
    compare_expected_actual(expected, actual, name="maxpool2x2_dgk_online", get_relative=True)
def check_correctness_offline(x, w, output_mask, output_c):
    actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
    torch_x = x.reshape([1] + x_shape).cuda().double()
    torch_w = w.reshape(w_shape).cuda().double()
    expected = F.conv2d(torch_x, torch_w, padding=padding)
    expected = pmod(expected.reshape(output_mask.shape), modulus)
    compare_expected_actual(expected, actual, name=test_name + " offline", get_relative=True)
def correctness_conv2d(self, input_img, output, modulus):
    x = input_img.cuda().double()
    x = x.reshape([1] + list(x.shape))
    x = pmod(F.conv2d(x, self.layers[1].weight.cuda().double(), padding=1), modulus)
    expected = x
    actual = pmod(output, modulus)
    if len(expected.shape) == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="conv2d", get_relative=True)
def check_correctness_online(x, w, b, output_s, output_c):
    actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
    torch_x = x.reshape([1] + x_shape).cuda().double()
    torch_w = w.reshape(w_shape).cuda().double()
    torch_b = b.cuda().double() if b is not None else None
    expected = F.conv2d(torch_x, torch_w, bias=torch_b, padding=padding)
    expected = pmod(expected.reshape(output_s.shape), modulus)
    compare_expected_actual(expected, actual, name=test_name + " online", get_relative=True)
def test_conv2d_fhe_ntt_single_thread():
    modulus = 786433
    img_hw = 16
    filter_hw = 3
    padding = 1
    num_input_channel = 64
    num_output_channel = 128
    data_bit = 17
    data_range = 2 ** data_bit
    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)
    # x = torch.arange(get_prod(x_shape)).reshape(x_shape)
    # w = torch.arange(get_prod(w_shape)).reshape(w_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw, num_input_channel,
                                    num_output_channel, fhe_builder, "test_conv2d_fhe_ntt", padding)
    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    actual = pmod(y - prot.output_mask_s, modulus)
    # print("actual\n", actual)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    # print("expected", expected)
    compare_expected_actual(expected, actual, name="test_conv2d_fhe_ntt_single_thread", get_relative=True)
def check_correctness_online(img, max_s, max_c):
    img = torch.where(img < q_23 // 2, img, img - q_23).cuda()
    pool = torch.nn.AvgPool2d(2)
    # AvgPool over a 2x2 window times 4 equals the window sum.
    expected = pool(img.double().reshape([-1, img_hw, img_hw])).reshape(-1) * 4
    expected = pmod(expected, q_23)
    actual = pmod(max_s + max_c, q_23)
    compare_expected_actual(expected, actual, name=test_name + "_online", get_relative=True)
def check_correctness_online(x, w, output_s, output_c):
    actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
    torch_x = x.reshape([1] + x_shape).cuda().double()
    torch_w = w.reshape(w_shape).cuda().double()
    expected = torch.mm(torch_x, torch_w.t())
    if bias is not None:
        expected += bias.cuda().double()
    expected = pmod(expected.reshape(output_s.shape), modulus)
    compare_expected_actual(expected, actual, name=test_name + " online", get_relative=True)
def correctness_maxpool2x2(self, input_img, output, modulus):
    torch_pool1 = torch.nn.MaxPool2d(2)
    x = input_img.cuda().double()
    x = x.reshape([1] + list(x.shape))
    x = pmod(torch_pool1(nmod(x, modulus)), modulus)
    expected = x
    actual = pmod(output, modulus)
    if len(expected.shape) == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="maxpool2x2", get_relative=True)
def test_fc_fhe_single_thread():
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17
    data_range = 2 ** data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)
    # w = gen_unirand_int_grain(0, modulus, get_prod(w_shape)).reshape(w_shape)
    # x = torch.arange(get_prod(x_shape)).reshape(x_shape)
    # w = torch.arange(get_prod(w_shape)).reshape(w_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("fc with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("fc masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    actual = pmod(y - prot.output_mask_s, modulus)
    # print("actual\n", actual)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Fc Torch"):
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected, actual, name=test_name, get_relative=True)
    print(f"\nTest for {test_name}: End")
def mod_div_online(self, z):
    pre_correct_mod_div_s = torch.where(z < self.nullify_threshold,
                                        self.elem_zeros + 1, self.elem_zeros)
    pre_correct_mod_div_s = pmod(pre_correct_mod_div_s - self.pre_mod_div_c, self.q_23)
    self.common.pre_corr_mod_s.send(pre_correct_mod_div_s)
def compute_with_weight(self, weight_tensor):
    assert weight_tensor.shape == self.weight_shape
    pod_vector = uIntVector()
    pt = Plaintext()
    self.output_cts = encrypt_zeros(self.num_output_batch, self.batch_encoder, self.encryptor, self.degree)
    for idx_output_batch, idx_input_batch in product(range(self.num_output_batch), range(self.num_input_batch)):
        encoding_tensor = torch.zeros(self.degree)
        is_w_changed = False
        for idx_piece in range(self.num_piece_in_batch):
            idx_row, idx_col_start, idx_col_end = self.index_weight_batch_to_units(
                idx_output_batch, idx_input_batch, idx_piece)
            if idx_row is False:
                continue
            is_w_changed = True
            padded_span = self.num_elem_in_piece
            data_span = idx_col_end - idx_col_start
            start_piece = idx_piece * padded_span
            encoding_tensor[start_piece:start_piece + data_span] = \
                weight_tensor[idx_row, idx_col_start:idx_col_end]
        if not is_w_changed:
            continue
        encoding_tensor = pmod(encoding_tensor, self.modulus)
        pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
        self.batch_encoder.encode(pod_vector, pt)
        # Multiply a copy of the encrypted input batch by the encoded weights
        # and accumulate into the corresponding output ciphertext.
        sub_dotted = Ciphertext(self.input_cts[idx_input_batch])
        self.evaluator.multiply_plain_inplace(sub_dotted, pt)
        self.evaluator.add_inplace(self.output_cts[idx_output_batch], sub_dotted)
def shift_by_exp(data, exp, mode="stochastic"):
    # Divide by 2^{-exp} in the masked domain: add a uniform mask r,
    # floor-divide the masked sum and the mask separately, and correct
    # the quotient when x + r wrapped past the modulus.
    d = 2 ** -exp
    p = modulus
    x = data

    n_elem = data.numel()
    meta_rg = MetaTruncRandomGenerator()
    rg = meta_rg.get_rg("plain")
    r = rg.gen_uniform(n_elem, p).cuda().reshape_as(x)

    x = nmod(x, p)
    x = F.relu(x)

    psum_xr = pmod(x + r, p)
    # If x + r wrapped, the naive quotient difference is off by p // d.
    wrapped = nmod(psum_xr // d - r // d + p // d, p)
    unwrapped = nmod(psum_xr // d - r // d, p)
    x = torch.where(psum_xr < r, wrapped, unwrapped)
    return x
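# A scalar check of the wrap correction in shift_by_exp: when x + r wraps
# past p, the separately-truncated quotients are off by p // d, and even
# after the correction the result may differ from x // d by at most one
# (values hypothetical):
p, d = 786433, 2 ** 8
x, r = 500000, 400000
psum = (x + r) % p                # wrapped, since psum < r
quot = (psum // d - r // d + p // d) % p
assert abs(quot - x // d) <= 1    # truncation error of at most one unit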
def decomp_to_bit(self, x, res=None):
    tmp_x = torch.clone(x).to(Config.device)
    res = torch.zeros([self.work_bit, self.num_elem]) if res is None else res
    for i in range(self.work_bit):
        res[i] = pmod(tmp_x, 2)
        tmp_x //= 2
    return res
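# Round-trip sanity check for the bit decomposition above: reconstructing
# sum_i res[i] * 2^i recovers the input, confirming res[0] is the LSB
# (standalone sketch with hypothetical values):
import torch

x = torch.tensor([0., 1., 5., 12.])
work_bit = 4
bits = torch.zeros([work_bit, x.numel()])
tmp = x.clone()
for i in range(work_bit):
    bits[i] = tmp % 2
    tmp = torch.div(tmp, 2, rounding_mode="floor")
weights = 2 ** torch.arange(work_bit, dtype=torch.float)
assert torch.equal(weights @ bits, x)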
def online(self, input_s):
    input_s = input_s.reshape([1] + list(self.input_shape)).cuda().double()
    y_s = torch.mm(input_s, self.weight.t())
    if self.bias is not None:
        y_s += self.bias
    y_s = pmod(y_s.reshape(self.output_shape) - self.output_mask_s, self.modulus)
    self.output_s = y_s
def test_ntt_conv():
    modulus = 786433
    img_hw = 16
    filter_hw = 3
    padding = 1
    num_input_channel = 64
    num_output_channel = 128
    data_bit = 17
    data_range = 2 ** data_bit
    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    # x = torch.arange(get_prod(x_shape)).reshape(x_shape)
    # w = torch.arange(get_prod(w_shape)).reshape(w_shape)

    conv2d_ntt = Conv2dNtt(modulus, data_range, img_hw, filter_hw, num_input_channel,
                           num_output_channel, "test_conv2d_ntt", padding)
    # Warm-up run; conv2d() also executes its internal single-channel self-checks.
    y = conv2d_ntt.conv2d(x, w)

    with NamedTimerInstance("ntt x"):
        conv2d_ntt.load_and_ntt_x(x)
    with NamedTimerInstance("ntt w"):
        conv2d_ntt.load_and_ntt_w(w)
    with NamedTimerInstance("conv2d"):
        y = conv2d_ntt.conv2d_loaded()
    actual = pmod(y, modulus)
    # print("actual\n", actual)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(conv2d_ntt.y_shape), modulus)
    # print("expected", expected)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True, show_where_err=False)
def check_correctness(input_img, output):
    torch_pool1 = torch.nn.MaxPool2d(2)
    x = input_img.to(Config.device).double()
    x = x.reshape([1] + list(x.shape))
    x = pmod(F.conv2d(x, conv1_w.to(Config.device).double(), padding=1), q_23)
    x = pmod(F.relu(nmod(x, q_23)), q_23)
    x = pmod(torch_pool1(nmod(x, q_23)), q_23)
    x = pmod(x // (2 ** pow_to_div), q_23)
    x = pmod(F.conv2d(x, conv2_w.to(Config.device).double(), padding=1), q_23)
    x = x.view(-1)
    x = pmod(torch.mm(x.view(1, -1), fc1_w.to(Config.device).double().t()).view(-1), q_23)
    expected = x
    actual = pmod(output, q_23)
    if len(expected.shape) == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name=test_name, get_relative=True)
def online(self, input_s):
    input_s = input_s.reshape([1] + list(self.input_shape)).cuda().double()
    y_s = F.conv2d(input_s, self.torch_w, padding=self.padding, bias=self.bias)
    y_s = pmod(y_s.reshape(self.y_shape) - self.output_mask_s, self.modulus)
    self.output_s = y_s
def check_correctness_mod_div(r, z, correct_mod_div_work_s, correct_mod_div_work_c):
    elem_zeros = torch.zeros(num_elem).to(Config.device)
    expected = torch.where(r > z, q_23 // work_range + elem_zeros, elem_zeros)
    actual = pmod(correct_mod_div_work_s + correct_mod_div_work_c, q_23)
    compare_expected_actual(expected, actual, get_relative=True, name="mod_div_online")
def conv2d_loaded(self):
    y = torch.zeros([self.num_output_channel, self.output_hw, self.output_hw]).double()
    for i, j in product(range(self.num_output_channel), range(self.num_input_channel)):
        single_y = self.conv2d_ntted_single_channel(self.ntted_x[j], self.ntted_w[i, j])
        y[i, :, :] += self.transform_y_single_channel(single_y)
    return pmod(y, self.modulus)
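# conv2d_loaded depends on the convolution theorem: pointwise products in the
# transform domain give a circular convolution, which equals the linear
# convolution once both operands are zero-padded to img_hw + filter_hw - 1.
# The same identity over the complex numbers, with FFT standing in for the
# modular NTT (a sketch of the math, not of the modular code):
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 16, 16)
w = torch.randn(1, 1, 3, 3)
n = 16 + 3 - 1  # pad size at which circular == linear convolution
y_fft = torch.fft.ifft2(torch.fft.fft2(x, s=(n, n)) * torch.fft.fft2(w, s=(n, n))).real
# F.conv2d computes cross-correlation; flipping the kernel turns it into
# convolution, and padding=2 yields the full n x n result.
y_ref = F.conv2d(x, torch.flip(w, dims=[-2, -1]), padding=2)
assert torch.allclose(y_fft, y_ref, atol=1e-4)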