def correctness_vgg(self, input_img, output, modulus):
    """Plaintext reference check for a VGG-style secure forward pass (currently disabled).

    NOTE(review): the body starts with a bare ``return``, so this method does
    nothing and returns None; every statement after it is unreachable. It looks
    deliberately switched off — confirm before re-enabling. The comparison name
    "minionn_mnist" below also looks copy-pasted from another checker.

    Args:
        input_img: plaintext input; a leading batch dimension of 1 is added.
        output: result of the secure computation, reduced with ``pmod`` before comparing.
        modulus: ring modulus passed to ``pmod``/``nmod`` (presumably positive /
            centered representatives respectively — confirm).
    """
    return
    torch_pool1 = torch.nn.MaxPool2d(2)
    torch_pool2 = torch.nn.MaxPool2d(2)
    x = input_img.cuda().double()
    x = x.reshape([1] + list(x.shape))
    # Block 1: conv -> relu -> conv -> maxpool -> relu
    x = pmod(F.conv2d(x, self.layers[1].weight.cuda().double(), padding=1), modulus)
    x = pmod(F.relu(nmod(x, modulus)), modulus)
    x = pmod(F.conv2d(x, self.layers[3].weight.cuda().double(), padding=1), modulus)
    x = pmod(torch_pool1(nmod(x, modulus)), modulus)
    x = pmod(F.relu(nmod(x, modulus)), modulus)
    # Block 2: conv -> relu -> conv -> maxpool -> relu
    x = pmod(F.conv2d(x, self.layers[6].weight.cuda().double(), padding=1), modulus)
    x = pmod(F.relu(nmod(x, modulus)), modulus)
    x = pmod(F.conv2d(x, self.layers[8].weight.cuda().double(), padding=1), modulus)
    x = pmod(torch_pool2(nmod(x, modulus)), modulus)
    x = pmod(F.relu(nmod(x, modulus)), modulus)
    # Block 3: three conv -> relu pairs
    x = pmod(F.conv2d(x, self.layers[11].weight.cuda().double(), padding=1), modulus)
    x = pmod(F.relu(nmod(x, modulus)), modulus)
    x = pmod(F.conv2d(x, self.layers[13].weight.cuda().double(), padding=1), modulus)
    x = pmod(F.relu(nmod(x, modulus)), modulus)
    x = pmod(F.conv2d(x, self.layers[15].weight.cuda().double(), padding=1), modulus)
    x = pmod(F.relu(nmod(x, modulus)), modulus)
    # Flatten and apply the final fully-connected layer.
    x = x.view(-1)
    x = pmod(torch.mm(x.view(1, -1), self.layers[18].weight.cuda().double().t()).view(-1), modulus)
    expected = x
    actual = pmod(output, modulus)
    # NOTE(review): ``expected`` is 1-D after ``.view(-1)``, so this branch cannot fire.
    if len(expected.shape) == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="minionn_mnist", get_relative=True)
def check_correctness_online(output, input_s, input_c):
    """Verify the online phase: the two shares must sum to ``output`` mod ``modulus``.

    ``modulus`` and ``test_name`` are read from the enclosing module scope.
    """
    reference = pmod(output.cuda(), modulus)
    share_sum = input_s.cuda() + input_c.cuda()
    compare_expected_actual(
        reference,
        pmod(share_sum, modulus),
        name=test_name + " online",
        get_relative=True,
    )
def conv2d(self, x, w):
    """Convolve ``x`` with ``w`` via NTT, after sanity-checking the NTT round trip.

    Loads and transforms ``w`` then ``x``. It then checks that inverse-transforming
    the first input channel and the first filter, cropped back to their original
    sizes, recovers the originals (the filter comes back rotated by 180 degrees).
    It runs the single-channel path once and returns ``self.conv2d_loaded()``.
    """
    self.load_and_ntt_w(w)
    self.load_and_ntt_x(x)

    # Round trip of input channel 0: iNTT, then crop to img_hw x img_hw.
    ntt_x0 = self.ntted_x[0]
    recovered_x0 = self.ntt_matmul.intt2d(ntt_x0.double())[:self.img_hw, :self.img_hw]
    compare_expected_actual(
        pmod(x[0], self.modulus),
        pmod(recovered_x0, self.modulus),
        get_relative=True,
        name="sub_x",
    )

    # Round trip of filter (0, 0): iNTT, then crop to filter_hw x filter_hw.
    ntt_w00 = self.ntted_w[0, 0]
    recovered_w00 = self.ntt_matmul.intt2d(ntt_w00.double())[:self.filter_hw, :self.filter_hw]
    compare_expected_actual(
        pmod(w[0, 0], self.modulus).rot90(2),
        pmod(recovered_w00, self.modulus),
        get_relative=True,
        name="sub_w",
    )

    # Exercise the single-channel path once; its result is not used.
    self.transform_y_single_channel(self.conv2d_ntted_single_channel(ntt_x0, ntt_w00))
    return self.conv2d_loaded()
def test_torch_ntt():
    """Round-trip a random ``img_hw`` x ``img_hw`` matrix through the matrix-form 2-D NTT.

    ``x`` is transformed with ``ntt_mat`` and back with ``inv_mat`` via
    ``ntt_mat2d``; the result must equal ``x`` modulo ``modulus``.
    """
    modulus = 786433
    img_hw = 64
    data_bit = 17
    len_vector = img_hw
    data_range = 2**data_bit

    _root, mod, ntt_mat, inv_mat = generate_ntt_param(modulus, len_vector, data_range)

    # Only ``x`` is needed: the round trip does not involve a filter.
    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, img_hw**2).reshape([img_hw, img_hw]).double()

    ntt_mat = ntt_mat.double()
    inv_mat = inv_mat.double()
    with NamedTimerInstance("Mat NTT 2d"):
        ntted = ntt_mat2d(ntt_mat, mod, x)
    reved = ntt_mat2d(inv_mat, mod, ntted)

    expected = pmod(x, modulus).type(torch.int)
    actual = pmod(reved, modulus).type(torch.int)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
def test_client():
    """Client side of the FHE key-exchange test.

    Generates keys and publishes the public key. It then receives and decrypts the
    blob tagged ``ori_tag``, and receives the ciphertext tagged ``pk_tag`` and
    decrypts it. Finally it publishes the secret key and sends a ciphertext
    (tag ``sk_tag``) built from ``expected_tensor_sk``. Relies on module-level
    ``modulus``, ``degree``, ``num_elem``, tags and expected tensors.
    """
    init_communicate(Config.client_rank)
    builder = FheBuilder(modulus, degree)
    comm_builder = CommFheBuilder(Config.client_rank, builder, "fhe_builder")
    builder.generate_keys()
    comm_builder.send_public_key()

    # 1) Receive and decrypt the blob.
    incoming_blob = BlobFheEnc(num_elem, comm_builder, ori_tag)
    incoming_blob.prepare_recv()
    decrypted_blob = incoming_blob.get_recv_decrypt()
    compare_expected_actual(ori, decrypted_blob, get_relative=True, name=ori_tag)

    # 2) Receive a ciphertext into a pre-built buffer, wait, decrypt.
    pk_cipher = builder.build_enc(num_elem)
    comm_builder.recv_enc(pk_cipher, pk_tag)
    comm_builder.wait_enc(pk_tag)
    recovered_pk_tensor = builder.decrypt_to_torch(pk_cipher)
    compare_expected_actual(expected_tensor_pk, recovered_pk_tensor, get_relative=True,
                            name="Recovering to test pk")

    # 3) Publish the secret key and send a ciphertext for the peer to decrypt.
    comm_builder.send_secret_key()
    sk_cipher = builder.build_enc_from_torch(expected_tensor_sk)
    comm_builder.send_enc(sk_cipher, sk_tag)

    dist.destroy_process_group()
def test_fhe_builder_rebuild():
    """Check that a FheBuilder can be rebuilt from another builder's key buffers.

    The "server" builder copies the client's public key, encrypts a tensor, and
    the client must decrypt it. The server then copies the client's secret key,
    and must decrypt a tensor the client encrypted.
    """
    print()
    print("Test for FheBuilder rebuild: start")
    modulus, degree = 12289, 2048
    num_elem = 2 ** 17
    print(f"modulus: {modulus}, degree: {degree}")

    client = FheBuilder(modulus, degree)
    server = FheBuilder(modulus, degree)
    client.generate_keys()

    # Public key: server encrypts with the copied key, client decrypts.
    server_plain = torch.ones(num_elem).float() + 5
    server.get_public_key_buffer().copy_(client.get_public_key_buffer())
    server.build_from_loaded_public_key()
    pk_cipher = server.build_enc_from_torch(server_plain)
    compare_expected_actual(server_plain, client.decrypt_to_torch(pk_cipher),
                            get_relative=True, name="Rebuilding public key")

    # Secret key: client encrypts, server decrypts with the copied key.
    client_plain = torch.ones(num_elem).float() + 12
    server.get_secret_key_buffer().copy_(client.get_secret_key_buffer())
    server.build_from_loaded_secret_key()
    sk_cipher = client.build_enc_from_torch(client_plain)
    compare_expected_actual(client_plain, server.decrypt_to_torch(sk_cipher),
                            get_relative=True, name="Rebuilding secret key")

    print("Test for FheBuilder rebuild: end")
    print()
def check_correctness(input_img, output):
    """Compare a secure inference result with a plaintext torch reference mod ``q_23``.

    Reference pipeline: conv1 (padding 1) -> ReLU -> 2x2 max-pool -> floor
    division by ``2**pow_to_div`` -> conv2 (padding 1) -> flatten -> fc1.
    ReLU and max-pool run on ``nmod``-mapped values, and every stage is
    reduced with ``pmod``. Reads module-level ``conv1_w``, ``conv2_w``,
    ``fc1_w``, ``pow_to_div``, ``q_23`` and ``test_name``.
    """
    device = Config.device
    pool = torch.nn.MaxPool2d(2)
    x = input_img.to(device).double()
    x = x.reshape([1] + list(x.shape))  # add a batch dimension for F.conv2d
    x = pmod(F.conv2d(x, conv1_w.to(device).double(), padding=1), q_23)
    x = pmod(F.relu(nmod(x, q_23)), q_23)
    x = pmod(pool(nmod(x, q_23)), q_23)
    x = pmod(x // (2**pow_to_div), q_23)  # rescale by truncating division
    x = pmod(F.conv2d(x, conv2_w.to(device).double(), padding=1), q_23)
    x = x.view(-1)
    expected = pmod(torch.mm(x.view(1, -1), fc1_w.to(device).double().t()).view(-1), q_23)
    # ``expected`` is already 1-D (final ``.view(-1)``), so no batch squeeze is needed.
    actual = pmod(output, q_23)
    compare_expected_actual(expected, actual, name=test_name, get_relative=True)
def check_correctness_offline(u, v, z_s, z_c):
    """Verify offline multiplication shares: ``z_s + z_c`` must equal ``u * v`` mod ``modulus``.

    The product is formed elementwise in float64 on ``Config.device``.
    """
    device = Config.device
    lhs, rhs = u.double().to(device), v.double().to(device)
    compare_expected_actual(
        pmod(lhs * rhs, modulus),
        pmod(z_s + z_c, modulus),
        name="shares_mult_offline",
        get_relative=True,
    )
def correctness_fc(self, input_img, output, modulus):
    """Compare a secure single-FC-layer output with ``torch.mm`` modulo ``modulus``.

    The input is flattened to one row and multiplied by the transpose of
    ``self.layers[1].weight``.

    Args:
        input_img: plaintext input (any shape; it is flattened).
        output: secure result, reduced with ``pmod`` before comparing.
        modulus: ring modulus.
    """
    row = input_img.cuda().double().view(1, -1)
    weight = self.layers[1].weight.cuda().double()
    expected = pmod(torch.mm(row, weight.t()).view(-1), modulus)
    actual = pmod(output, modulus)
    # ``expected`` is 1-D here, so the 4-D batch squeeze used by the conv
    # checkers is unnecessary.
    compare_expected_actual(expected, actual, name="fc", get_relative=True)
def check_correctness_online(a, b, c_s, c_c):
    """Verify online multiplication shares: ``c_s + c_c`` must equal ``a * b`` mod ``modulus``.

    The product is formed elementwise in float64 on ``Config.device``.
    """
    device = Config.device
    lhs, rhs = a.double().to(device), b.double().to(device)
    compare_expected_actual(
        pmod(lhs * rhs, modulus),
        pmod(c_s + c_c, modulus),
        name="shares_mult_online",
        get_relative=True,
    )
def check_correctness_offline(x, w, output_mask, output_c):
    """Verify offline conv shares: ``output_c - output_mask`` must equal ``conv2d(x, w)`` mod ``modulus``.

    Reads module-level ``x_shape``, ``w_shape``, ``padding``, ``modulus`` and ``test_name``.
    """
    unmasked = pmod(output_c.cuda() - output_mask.cuda(), modulus)
    batched_x = x.reshape([1] + x_shape).cuda().double()
    kernel = w.reshape(w_shape).cuda().double()
    reference = F.conv2d(batched_x, kernel, padding=padding).reshape(output_mask.shape)
    compare_expected_actual(pmod(reference, modulus), unmasked,
                            name=test_name + " offline", get_relative=True)
def check_correctness(x, y, dgk_x_leq_y_s, dgk_x_leq_y_c):
    """Verify DGK comparison shares: they must reconstruct ``x <= y`` mod ``q_23``.

    ``x`` and ``y`` are first mapped to signed values (entries >= ``q_23 // 2``
    are shifted down by ``q_23``). Also prints the number of mismatching entries.
    """
    def to_signed(t):
        return torch.where(t < q_23 // 2, t, t - q_23).to(Config.device)

    signed_x = to_signed(x)
    signed_y = to_signed(y)
    expected_x_leq_y = signed_x <= signed_y
    reconstructed = pmod(dgk_x_leq_y_s + dgk_x_leq_y_c, q_23)
    compare_expected_actual(expected_x_leq_y, reconstructed, name="DGK x <= y", get_relative=True)
    print(torch.sum(expected_x_leq_y != reconstructed))
def check_correctness_online(img, max_s, max_c):
    """Verify 2x2 max-pool shares: ``max_s + max_c`` must equal ``MaxPool2d(2)(img)`` mod ``q_23``.

    ``img`` is first mapped to signed values (entries >= ``q_23 // 2`` shifted
    down by ``q_23``) and viewed as ``[-1, img_hw, img_hw]``.
    """
    signed = torch.where(img < q_23 // 2, img, img - q_23).to(Config.device)
    pooled = torch.nn.MaxPool2d(2)(signed.reshape([-1, img_hw, img_hw])).reshape(-1)
    compare_expected_actual(
        pmod(pooled, q_23),
        pmod(max_s + max_c, q_23),
        name="maxpool2x2_dgk_online",
        get_relative=True,
    )
def check_correctness_mod_div(r, z, correct_mod_div_work_s, correct_mod_div_work_c):
    """Verify mod-div correction shares.

    The expected value per element is ``q_23 // work_range`` where ``r > z``
    and 0 elsewhere. The two shares are summed and reduced mod ``q_23``.
    Reads module-level ``num_elem``, ``q_23`` and ``work_range``.
    """
    zeros = torch.zeros(num_elem).to(Config.device)
    correction = q_23 // work_range + zeros
    expected = torch.where(r > z, correction, zeros)
    reconstructed = pmod(correct_mod_div_work_s + correct_mod_div_work_c, q_23)
    compare_expected_actual(expected, reconstructed, get_relative=True, name="mod_div_online")
def correctness_relu_only_nn(self, input_img, output, modulus):
    """Compare a secure ReLU-only network's output with plaintext ReLU modulo ``modulus``.

    ReLU is applied to the ``nmod``-mapped input and the result reduced with ``pmod``.
    A leading batch dimension of 1 is added and stripped again before comparing.
    """
    batched = input_img.cuda().double()
    batched = batched.reshape([1] + list(batched.shape))
    reference = pmod(F.relu(nmod(batched, modulus)), modulus)
    secure = pmod(output, modulus)
    if reference.dim() == 4 and reference.shape[0] == 1:
        reference = reference.reshape(reference.shape[1:])
    compare_expected_actual(reference, secure, name="relu_only_nn", get_relative=True)
def correctness_conv2d(self, input_img, output, modulus):
    """Compare a secure single-conv output with ``F.conv2d`` (padding 1) modulo ``modulus``.

    Uses ``self.layers[1].weight`` as the kernel. A leading batch dimension of 1
    is added and stripped again before comparing.
    """
    batched = input_img.cuda().double()
    batched = batched.reshape([1] + list(batched.shape))
    kernel = self.layers[1].weight.cuda().double()
    reference = pmod(F.conv2d(batched, kernel, padding=1), modulus)
    secure = pmod(output, modulus)
    if reference.dim() == 4 and reference.shape[0] == 1:
        reference = reference.reshape(reference.shape[1:])
    compare_expected_actual(reference, secure, name="conv2d", get_relative=True)
def test_conv2d_fhe_ntt_single_thread():
    """Check ``Conv2dFheNttSingleThread`` against ``F.conv2d`` on random data.

    A 64 -> 128 channel, 3x3, padding-1 convolution on a 16x16 input is run
    through the protocol. The decoded output minus ``prot.output_mask_s`` must
    equal the plaintext convolution modulo ``modulus``.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    n_in, n_out = 64, 128
    data_bit = 17
    data_range = 2**data_bit
    x_shape = [n_in, img_hw, img_hw]
    w_shape = [n_out, n_in, filter_hw, filter_hw]

    builder = FheBuilder(modulus, Config.n_23)
    builder.generate_keys()

    # Signed weights in the data range; inputs drawn between 0 and modulus.
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw, n_in, n_out,
                                    builder, "test_conv2d_fhe_ntt", padding)
    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        decoded = prot.decode_output_from_fhe_batch()
    unmasked = pmod(decoded - prot.output_mask_s, modulus)

    batched_x = x.reshape([1] + x_shape).double()
    kernel = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(batched_x, kernel, padding=padding)
    reference = pmod(reference.reshape(prot.output_shape), modulus)
    compare_expected_actual(reference, unmasked, name="test_conv2d_fhe_ntt_single_thread", get_relative=True)
def check_correctness_online(x, w, b, output_s, output_c):
    """Verify online conv shares: ``output_s + output_c`` must equal ``conv2d(x, w, bias=b)`` mod ``modulus``.

    ``b`` may be None (no bias). Reads module-level ``x_shape``, ``w_shape``,
    ``padding``, ``modulus`` and ``test_name``.
    """
    reconstructed = pmod(output_s.cuda() + output_c.cuda(), modulus)
    batched_x = x.reshape([1] + x_shape).cuda().double()
    kernel = w.reshape(w_shape).cuda().double()
    bias_vec = None if b is None else b.cuda().double()
    reference = F.conv2d(batched_x, kernel, padding=padding, bias=bias_vec)
    reference = pmod(reference.reshape(output_s.shape), modulus)
    compare_expected_actual(reference, reconstructed, name=test_name + " online", get_relative=True)
def check_correctness_online(img, max_s, max_c):
    """Verify 2x2 sum-pool shares: ``max_s + max_c`` must equal each 2x2 window sum mod ``q_23``.

    The window sum is computed as ``AvgPool2d(2)`` times 4 on the signed image
    (entries >= ``q_23 // 2`` shifted down by ``q_23``). Reads module-level
    ``q_23``, ``img_hw`` and ``test_name``.
    """
    signed = torch.where(img < q_23 // 2, img, img - q_23).cuda()
    window_mean = torch.nn.AvgPool2d(2)(signed.double().reshape([-1, img_hw, img_hw]))
    expected = pmod(window_mean.reshape(-1) * 4, q_23)
    reconstructed = pmod(max_s + max_c, q_23)
    compare_expected_actual(expected, reconstructed, name=test_name + "_online", get_relative=True)
def test_1d_server():
    """Server side of the 1-D ciphertext refresh test.

    Encrypts a random tensor, asks the peer to refresh it through
    ``EncRefresherServer.request``, and checks that the refreshed ciphertext
    still decrypts to the original tensor.
    """
    init_communicate(Config.server_rank)
    length = num_elem
    plain = gen_unirand_int_grain(0, modulus, length)
    refresher = EncRefresherServer(length, fhe_builder, "test_1d_refresher")
    refreshed_cipher = refresher.request(fhe_builder.build_enc_from_torch(plain))
    compare_expected_actual(plain, fhe_builder.decrypt_to_torch(refreshed_cipher),
                            get_relative=True, name="1d_refresh")
    end_communicate()
def test_fc_fhe_single_thread():
    """Check ``FcFheSingleThread`` against ``torch.mm`` on random data.

    A 512 -> 512 fully-connected layer is run through the protocol. The decoded
    output minus ``prot.output_mask_s`` must equal ``x @ w.T`` modulo ``modulus``.
    """
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17
    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Signed weights in the data range; inputs drawn between 0 and modulus.
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    # Timer labels name the FC operations actually being timed.
    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("fc with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("fc masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    actual = pmod(y - prot.output_mask_s, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Fc Torch"):
        expected = torch.mm(torch_x, torch_w.t())
    expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected, actual, name=test_name, get_relative=True)
    print(f"\nTest for {test_name}: End")
def correctness_maxpool2x2(self, input_img, output, modulus):
    """Compare a secure 2x2 max-pool output with ``MaxPool2d(2)`` modulo ``modulus``.

    Pooling runs on the ``nmod``-mapped input and the result is reduced with
    ``pmod``. A leading batch dimension of 1 is added and stripped again.
    """
    batched = input_img.cuda().double()
    batched = batched.reshape([1] + list(batched.shape))
    reference = pmod(torch.nn.MaxPool2d(2)(nmod(batched, modulus)), modulus)
    secure = pmod(output, modulus)
    if reference.dim() == 4 and reference.shape[0] == 1:
        reference = reference.reshape(reference.shape[1:])
    compare_expected_actual(reference, secure, name="maxpool2x2", get_relative=True)
def check_correctness_online(x, w, output_s, output_c):
    """Verify online FC shares: ``output_s + output_c`` must equal ``x @ w.T (+ bias)`` mod ``modulus``.

    Reads module-level ``x_shape``, ``w_shape``, ``bias`` (may be None),
    ``modulus`` and ``test_name``.
    """
    reconstructed = pmod(output_s.cuda() + output_c.cuda(), modulus)
    row = x.reshape([1] + x_shape).cuda().double()
    weight = w.reshape(w_shape).cuda().double()
    reference = torch.mm(row, weight.t())
    if bias is not None:
        reference += bias.cuda().double()
    reference = pmod(reference.reshape(output_s.shape), modulus)
    compare_expected_actual(reference, reconstructed, name=test_name + " online", get_relative=True)
def comm_base_server():
    """Server half of the CommBase dtype round-trip benchmark.

    The server sends int16 and uint8 tensors and receives float and double
    tensors, timing two rounds for each dtype pair. It then checks the received
    tensors against ``expect_float`` / ``expect_double`` and times
    ``NetworkSimulator.simulate`` on the sent tensors. Must run alongside
    ``comm_base_client``; relies on module-level tags, ``num_elem`` and
    ``comm_name``.
    """
    init_communicate(Config.server_rank)
    comm_base = CommBase(Config.server_rank, comm_name)
    send_int16 = expect_int16.cuda()
    send_uint8 = expect_uint8.cuda()
    # Round 1 (float in / int16 out): the receive is posted inside the timed region.
    with NamedTimerInstance("Server float and int16"):
        comm_base.recv_torch(torch.zeros(num_elem).float(), float_tag)
        comm_base.send_torch(send_int16, int16_tag)
        comm_base.wait(float_tag)
        actual_float = comm_base.get_tensor(float_tag).cuda()
    # Pre-post the next receive and synchronise both ranks before timing again.
    comm_base.recv_torch(torch.zeros(num_elem).float(), float_tag)
    dist.barrier()
    with NamedTimerInstance("Server float and int16"):
        comm_base.send_torch(send_int16, int16_tag)
        comm_base.wait(float_tag)
        actual_float = comm_base.get_tensor(float_tag).cuda()
    # Same two-round pattern for double in / uint8 out.
    comm_base.recv_torch(torch.zeros(num_elem).double(), double_tag)
    dist.barrier()
    with NamedTimerInstance("Server double and uint8"):
        comm_base.send_torch(send_uint8, uint8_tag)
        comm_base.wait(double_tag)
        actual_double = comm_base.get_tensor(double_tag).cuda()
    comm_base.recv_torch(torch.zeros(num_elem).double(), double_tag)
    dist.barrier()
    with NamedTimerInstance("Server double and uint8"):
        comm_base.send_torch(send_uint8, uint8_tag)
        comm_base.wait(double_tag)
        actual_double = comm_base.get_tensor(double_tag).cuda()
    dist.barrier()
    # Only the last round's received tensors are checked.
    compare_expected_actual(expect_float.cuda(), actual_float, name="float", get_relative=True)
    compare_expected_actual(expect_double.cuda(), actual_double, name="double", get_relative=True)
    # Simulated transfer times for comparison (10 * 10**9 bandwidth, .001 base latency).
    google_vm_simulator = NetworkSimulator(bandwidth=10 * (10**9), basis_latency=.001)
    with NamedTimerInstance("Simulate int16"):
        google_vm_simulator.simulate(send_int16.cpu().cuda())
    with NamedTimerInstance("Simulate uint8"):
        google_vm_simulator.simulate(send_uint8.cpu().cuda())
    dist.destroy_process_group()
def comm_base_client():
    """Client half of the CommBase dtype round-trip benchmark.

    The client sends float and double tensors and receives int16 and uint8
    tensors, timing two rounds for each dtype pair. It then checks the received
    tensors against ``expect_int16`` / ``expect_uint8``. Must run alongside
    ``comm_base_server``; relies on module-level tags, ``num_elem`` and
    ``comm_name``.
    """
    init_communicate(Config.client_rank)
    comm_base = CommBase(Config.client_rank, comm_name)
    send_float = expect_float.cuda()
    send_double = expect_double.cuda()
    # Round 1 (int16 in / float out): the receive is posted inside the timed region.
    with NamedTimerInstance("Client float and int16"):
        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.int16), int16_tag)
        comm_base.send_torch(send_float, float_tag)
        comm_base.wait(int16_tag)
        actual_int16 = comm_base.get_tensor(int16_tag).cuda()
    # Pre-post the next receive and synchronise both ranks before timing again.
    comm_base.recv_torch(
        torch.zeros(num_elem).type(torch.int16), int16_tag)
    dist.barrier()
    with NamedTimerInstance("Client float and int16"):
        comm_base.send_torch(send_float, float_tag)
        comm_base.wait(int16_tag)
        actual_int16 = comm_base.get_tensor(int16_tag).cuda()
    # Same two-round pattern for uint8 in / double out.
    comm_base.recv_torch(
        torch.zeros(num_elem).type(torch.uint8), uint8_tag)
    dist.barrier()
    with NamedTimerInstance("Client double and uint8"):
        comm_base.send_torch(send_double, double_tag)
        comm_base.wait(uint8_tag)
        actual_uint8 = comm_base.get_tensor(uint8_tag).cuda()
    comm_base.recv_torch(
        torch.zeros(num_elem).type(torch.uint8), uint8_tag)
    dist.barrier()
    with NamedTimerInstance("Client double and uint8"):
        comm_base.send_torch(send_double, double_tag)
        comm_base.wait(uint8_tag)
        actual_uint8 = comm_base.get_tensor(uint8_tag).cuda()
    dist.barrier()
    # Only the last round's received tensors are checked.
    compare_expected_actual(expect_int16.cuda(), actual_int16, name="int16", get_relative=True)
    compare_expected_actual(expect_uint8.cuda(), actual_uint8, name="uint8", get_relative=True)
    dist.destroy_process_group()
def test_fhe_enc_copy():
    """Check that a copied FHE ciphertext decrypts to the original plaintext."""
    print()
    print("Test for FheEncTensor Copy: start")
    modulus, degree = 12289, 2048
    num_elem = 2 ** 17
    print(f"modulus: {modulus}, degree: {degree}")

    builder = FheBuilder(modulus, degree)
    builder.generate_keys()

    plain = torch.ones(num_elem).float() + 5
    source_cipher = builder.build_enc_from_torch(plain)
    # Keep ``source_cipher`` referenced while its copy is decrypted.
    copied_cipher = source_cipher.copy()
    compare_expected_actual(plain, builder.decrypt_to_torch(copied_cipher),
                            get_relative=True, name="fhe enc copy")

    print("Test for FheEncTensor Copy: end")
    print()
def test_nnt_conv_single_channel():
    """Check a single-channel 2-D convolution computed via NTT against ``F.conv2d``.

    ``x`` and the 180-degree-rotated ``w`` are padded to ``padded_hw`` (from
    ``get_pow_2_ceil(img_hw + 2 * padding)``). Each is transformed with ``ntt``
    through ``transform2d``, multiplied point-wise, and inverse-transformed with
    ``intt``. The ``output_hw`` x ``output_hw`` window starting at
    ``output_offset`` is compared against torch's convolution modulo ``modulus``.
    """
    modulus = 786433
    img_hw = 6
    filter_hw = 3
    padding = 1
    conv_hw = img_hw + 2 * padding
    padded_hw = get_pow_2_ceil(conv_hw)
    output_hw = img_hw + 2 * padding - (filter_hw - 1)
    output_offset = filter_hw - 2

    # Deterministic inputs so mismatches are easy to read off the printed arrays.
    # ``astype(int)`` is what ``astype(np.int)`` meant; ``np.int`` was removed in NumPy 1.24.
    x = np.arange(img_hw**2).reshape([img_hw, img_hw]).astype(int)
    w = np.arange(filter_hw**2).reshape([filter_hw, filter_hw]).astype(int)
    padded_x = pad_to_size(x, padded_hw)
    padded_w = pad_to_size(np.rot90(w, 2), padded_hw)
    print(padded_x)
    print(padded_w)

    with NamedTimerInstance("NTT2D, Sympy"):
        ntted_x = transform2d(padded_x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
        ntted_w = transform2d(padded_w, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("Point-wise Dot"):
        doted = ntted_x * ntted_w
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(doted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))
    actual = reved[output_offset:output_hw + output_offset, output_offset:output_hw + output_offset]
    print("reved\n", reved)
    print("actual\n", actual)

    # F.conv2d has no int64 CPU kernel, so compute the reference in float64
    # (exact for these magnitudes), as the other conv tests do.
    torch_x = torch.tensor(x).reshape([1, 1, img_hw, img_hw]).double()
    torch_w = torch.tensor(w).reshape([1, 1, filter_hw, filter_hw]).double()
    expected = F.conv2d(torch_x, torch_w, padding=padding)
    expected = pmod(expected.reshape(output_hw, output_hw), modulus)
    print("expected", expected)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
def test_ntt_conv():
    """Check ``Conv2dNtt`` against ``F.conv2d`` on random signed data.

    Runs ``conv2d`` once for its built-in NTT round-trip self-checks. It then
    times the load/NTT steps and ``conv2d_loaded`` separately and compares the
    result with torch's convolution modulo ``modulus``.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    n_in, n_out = 64, 128
    data_bit = 17
    data_range = 2**data_bit
    x_shape = [n_in, img_hw, img_hw]
    w_shape = [n_out, n_in, filter_hw, filter_hw]
    low, high = -data_range // 2 + 1, data_range // 2
    x = gen_unirand_int_grain(low, high, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(low, high, get_prod(w_shape)).reshape(w_shape)

    conv = Conv2dNtt(modulus, data_range, img_hw, filter_hw, n_in, n_out, "test_conv2d_ntt", padding)
    # Untimed warm-up call; its result is superseded below.
    conv.conv2d(x, w)

    with NamedTimerInstance("ntt x"):
        conv.load_and_ntt_x(x)
    with NamedTimerInstance("ntt w"):
        conv.load_and_ntt_w(w)
    with NamedTimerInstance("conv2d"):
        result = conv.conv2d_loaded()
    actual = pmod(result, modulus)

    batched_x = x.reshape([1] + x_shape).double()
    kernel = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(batched_x, kernel, padding=padding)
    reference = pmod(reference.reshape(conv.y_shape), modulus)
    compare_expected_actual(reference, actual, name="ntt", get_relative=True, show_where_err=False)
def check_correctness(self, input_img, output, modulus):
    """Compare the secure network's output with the plaintext network's output.

    Both are mapped with ``nmod`` before comparing. Also prints both vectors and
    whether their arg-max indices (over dim 0) agree.
    """
    net = get_plain_net()
    batched = input_img.reshape([1] + list(input_img.shape)).cuda()
    logits = net(batched)
    expected = nmod(logits.reshape(logits.shape[1:]), modulus)
    actual = nmod(output, modulus).cuda()
    print("expected", expected)
    print("actual", actual)
    compare_expected_actual(expected, actual, name="secure_vgg", get_relative=True)
    expected_max = torch.max(expected, 0)[1]
    actual_max = torch.max(actual, 0)[1]
    print(f"expected_max: {expected_max}, actual_max: {actual_max}, Match: {expected_max == actual_max}")
def test_noise():
    """FheBuilder round-trip checks plus a noise-budget report around one cipher-plain multiply.

    Steps:
    1. Plaintext encode/export round trip.
    2. Encrypt/decrypt round trip.
    3. ``ct *= pt`` with the noise budget printed before and after; the decrypted
       product must equal the elementwise product mod ``modulus``.
    """
    print()
    print("Test for FheBuilder: start")
    modulus, degree = 12289, 2048
    num_elem = 2 ** 14
    builder = FheBuilder(modulus, degree)
    builder.generate_keys()
    print(f"modulus: {modulus}, degree: {degree}")
    print()

    def timed(label):
        # Every timer label in this test carries the same element-count suffix.
        return NamedTimerInstance(f"{label} with num_elem: {num_elem}")

    bits = gen_unirand_int_grain(0, 2, num_elem)
    print(bits)
    values = gen_unirand_int_grain(0, modulus - 1, num_elem)

    # 1) Plaintext encode -> export must be exact.
    with timed("build_plain_from_torch"):
        plain = builder.build_plain_from_torch(bits)
    with timed("plain.export_as_torch_gpu()"):
        exported = plain.export_as_torch()
    print("Fhe Plain encrypt and decrypt: ", end="")
    assert compare_expected_actual(bits, exported, verbose=True, get_relative=True).RelAvgDiff == 0
    print()

    # 2) Encrypt -> decrypt must be exact.
    with timed("fhe_builder.build_enc"):
        cipher = builder.build_enc(num_elem)
    with timed("cipher.encrypt_additive"):
        cipher.encrypt_additive(bits)
    with timed("fhe_builder.decrypt_to_torch"):
        decrypted = builder.decrypt_to_torch(cipher)
    print("Fhe Enc encrypt and decrypt: ", end="")
    assert compare_expected_actual(bits, decrypted, verbose=True, get_relative=True).RelAvgDiff == 0
    print()

    # 3) Cipher-plain multiply, reporting the noise budget on either side.
    pt = builder.build_plain_from_torch(bits)
    ct = builder.build_enc_from_torch(values)
    builder.noise_budget(ct, name="before ep")
    with timed("EP Mult"):
        ct *= pt
    builder.noise_budget(ct, name="after ep")
    expected = pmod(bits.double() * values.double(), modulus)
    actual = builder.decrypt_to_torch(ct)
    assert compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0
    print()
    print("Test for FheBuilder: Finish")