Exemple #1
0
    def test_server():
        """Server side of the shares-multiplication test.

        Runs the offline and online phases of ``SharesMultServer``. It then
        receives the client's shares of ``u``, ``v``, ``z`` and ``c`` and
        checks the reconstructed values against plaintext products.

        ``num_elem``, ``modulus``, ``fhe_builder``, ``a``, ``b``, ``a_s`` and
        ``b_s`` are taken from the enclosing scope.
        """
        init_communicate(Config.server_rank)
        prot = SharesMultServer(num_elem, modulus, fhe_builder,
                                "test_shares_mult")
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(a_s, b_s)
            torch_sync()

        # Offline triple: receive the client's shares of u, v and z.
        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        blob_u_c.prepare_recv()
        blob_v_c.prepare_recv()
        blob_z_c.prepare_recv()
        torch_sync()
        u_c = blob_u_c.get_recv()
        v_c = blob_v_c.get_recv()
        z_c = blob_z_c.get_recv()
        u = pmod(prot.u_s + u_c, modulus)
        v = pmod(prot.v_s + v_c, modulus)
        # (u, v, z) comes out of the offline phase, so it is verified with the
        # offline checker rather than the online one.
        check_correctness_offline(u, v, prot.z_s, z_c)

        # Online result: receive the client's share of c and check a * b.
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        blob_c_c.prepare_recv()
        torch_sync()
        c_c = blob_c_c.get_recv()
        check_correctness_online(a, b, prot.c_s, c_c)
        end_communicate()
Exemple #2
0
def test_torch_ntt():
    """Round-trip a random 2-D tensor through the matrix NTT and its inverse.

    An ``img_hw x img_hw`` tensor of random integers is drawn with bounds
    ``-data_range // 2 + 1`` and ``data_range // 2``. It is transformed by
    ``ntt_mat2d`` with the forward matrix (timed), then with the inverse
    matrix. The result is compared with the input modulo ``modulus``.
    """
    modulus = 786433
    img_hw = 64
    data_bit = 17

    len_vector = img_hw
    data_range = 2**data_bit
    root, mod, ntt_mat, inv_mat = generate_ntt_param(modulus, len_vector,
                                                     data_range)

    # Only x is needed. The old ``np.arange(...).astype(np.int)`` placeholders
    # were overwritten immediately, and ``np.int`` no longer exists in
    # NumPy >= 1.24, so they are gone.
    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              img_hw**2).reshape([img_hw, img_hw]).double()

    ntt_mat = ntt_mat.double()
    inv_mat = inv_mat.double()

    with NamedTimerInstance("Mat NTT 2d"):
        ntted = ntt_mat2d(ntt_mat, mod, x)
    reved = ntt_mat2d(inv_mat, mod, ntted)
    expected = pmod(x, modulus).type(torch.int)
    actual = pmod(reved, modulus).type(torch.int)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
Exemple #3
0
    def conv2d(self, x, w):
        """Load and NTT both operands, sanity-check them, then convolve.

        After loading, the first input channel and the first filter are
        recovered through ``self.ntt_matmul.intt2d`` and compared with the
        originals. The filter comparison expects ``w[0, 0]`` rotated by 180
        degrees. The single-channel pipeline is also exercised once on those
        slices. The returned value is ``self.conv2d_loaded()``.
        """
        self.load_and_ntt_w(w)
        self.load_and_ntt_x(x)
        modulus = self.modulus

        # First input channel: inverse NTT, crop to the image size, compare.
        first_x = x[0]
        first_ntted_x = self.ntted_x[0]
        recovered_x = self.ntt_matmul.intt2d(first_ntted_x.double())
        recovered_x = recovered_x[:self.img_hw, :self.img_hw]
        compare_expected_actual(pmod(first_x, modulus),
                                pmod(recovered_x, modulus),
                                get_relative=True,
                                name="sub_x")

        # First filter: inverse NTT, crop to the filter size, compare against
        # the original filter rotated by 180 degrees.
        first_w = w[0, 0]
        first_ntted_w = self.ntted_w[0, 0]
        recovered_w = self.ntt_matmul.intt2d(first_ntted_w.double())
        recovered_w = recovered_w[:self.filter_hw, :self.filter_hw]
        compare_expected_actual(pmod(first_w, modulus).rot90(2),
                                pmod(recovered_w, modulus),
                                get_relative=True,
                                name="sub_w")

        # Run the single-channel path once; its output is not used here.
        single_dotted = self.conv2d_ntted_single_channel(first_ntted_x,
                                                         first_ntted_w)
        self.transform_y_single_channel(single_dotted)

        return self.conv2d_loaded()
Exemple #4
0
    def online(self, input_s):
        """Mask ``input_s``, exchange it, and store the output share.

        The share is moved to ``self.comp_device`` and offset by
        ``self.input_mask_s`` (mod ``self.modulus``), then sent through
        ``self.blob_masked_input_s``. Next, the value received from
        ``self.blob_masked_output_s`` has the same mask subtracted
        (mod ``self.modulus``), and the result becomes ``self.output_s``.
        """
        modulus = self.modulus
        mask = self.input_mask_s
        share = input_s.to(self.comp_device)
        self.blob_masked_input_s.send(pmod(share + mask, modulus))

        received = self.blob_masked_output_s.get_recv()
        self.output_s = pmod(received - mask, modulus)
Exemple #5
0
    def sum_c_i_offline(self, delta_a, fhe_beta_i_c, fhe_alpha_beta_xor_c, s,
                        alpha_i, ci_mask_s, mult_mask_s, shuffle_order):
        """Build the masked, encrypted c_i rows of the bitwise comparison.

        For ``i < self.work_bit`` the code computes (homomorphically)::

            row[i] = mult_mask_s[i] * (3 * sum_xor[i] - fhe_beta_i_c[i]
                                       + s + alpha_i[i]) + ci_mask_s[i]

        where ``sum_xor[i]`` is the initial ``build_enc`` ciphertext plus
        ``fhe_alpha_beta_xor_c[j]`` for ``j = i+1 .. work_bit-1``. The extra
        row ``self.work_bit`` is::

            row[wb] = mult_mask_s[wb] * (sum_xor[0] + fhe_alpha_beta_xor_c[0]
                                         + delta_a) + ci_mask_s[wb]

        Args:
            delta_a: tensor added (as plaintext) only to row ``self.work_bit``.
            fhe_beta_i_c: per-bit ciphertexts. NOTE: element ``i`` is
                multiplied IN PLACE by ``mult_mask_s[i]``, so the caller's
                objects are modified.
            fhe_alpha_beta_xor_c: per-bit ciphertexts accumulated into the
                suffix sums above.
            s: tensor; ``s * mult_mask_s[i]`` is reduced mod ``self.q_16`` in
                int64 before encoding.
            alpha_i: per-bit tensors; ``alpha_i[i] * mult_mask_s[i]`` is encoded
                without a ``pmod`` here. NOTE(review): confirm this product
                cannot exceed the plaintext modulus.
            ci_mask_s: per-row additive masks, indices ``0 .. self.work_bit``.
            mult_mask_s: per-row multiplicative masks, indices
                ``0 .. self.work_bit``.
            shuffle_order: forwarded to ``self.generate_fhe_shuffled`` when
                ``self.is_shuffle`` is set.

        Returns:
            The list of row ciphertexts. When ``self.is_shuffle`` is set, it is
            the list returned by ``generate_fhe_shuffled`` after the rows are
            refreshed through ``EncRefresherServer``.
        """
        # the last row of sum_xor is c_{-1}, which helps check the case with x == y
        fhe_builder = self.fhe_builder_16
        # fhe_sum_xor = [fhe_builder.build_enc(self.num_elem) for i in range(self.num_work_batch)]
        # Indices 0 .. self.work_bit are written below, so num_work_batch must be
        # at least work_bit + 1 (presumably exactly that — confirm).
        fhe_sum_xor = [None for i in range(self.num_work_batch)]
        # Seed for the suffix sums; presumably an encryption of zeros — confirm
        # against FheBuilder.build_enc.
        fhe_sum_xor[self.work_bit - 1] = fhe_builder.build_enc(self.num_elem)
        # Walk downward: row i = row i+1 + xor term of bit i+1 (suffix sum).
        for i in range(self.work_bit - 1)[::-1]:
            fhe_sum_xor[i] = fhe_sum_xor[i + 1].copy()
            fhe_sum_xor[i] += fhe_alpha_beta_xor_c[i + 1]
        # Extra row work_bit: the full xor sum (including bit 0) plus delta_a.
        fhe_delta_a = fhe_builder.build_plain_from_torch(delta_a)
        fhe_sum_xor[self.work_bit] = fhe_sum_xor[0].copy()
        fhe_sum_xor[self.work_bit] += fhe_alpha_beta_xor_c[0]
        fhe_sum_xor[self.work_bit] += fhe_delta_a
        del fhe_delta_a

        # Per-bit rows: scale by the multiplicative mask and add the other terms.
        for i in range(self.work_bit)[::-1]:
            # 3 * mult_mask_s[i], reduced mod q_16, multiplies the suffix sum.
            fhe_mult_3 = fhe_builder.build_plain_from_torch(
                pmod(3 * mult_mask_s[i].cpu(), self.q_16))
            fhe_mult_mask_s = fhe_builder.build_plain_from_torch(
                mult_mask_s[i])
            # int64 product avoids float precision loss before the reduction.
            masked_s = pmod(
                s.type(torch.int64) * mult_mask_s[i].type(torch.int64),
                self.q_16).type(torch.float32)
            # print("s * mult_mask_s[i]", torch.max(masked_s))
            fhe_s = fhe_builder.build_plain_from_torch(masked_s)
            fhe_alpha_i = fhe_builder.build_plain_from_torch(alpha_i[i] *
                                                             mult_mask_s[i])
            fhe_ci_mask_s = fhe_builder.build_plain_from_torch(ci_mask_s[i])
            # In-place on the caller's ciphertext.
            fhe_beta_i_c[i] *= fhe_mult_mask_s
            fhe_sum_xor[i] *= fhe_mult_3
            fhe_sum_xor[i] -= fhe_beta_i_c[i]
            fhe_sum_xor[i] += fhe_s
            fhe_sum_xor[i] += fhe_alpha_i
            fhe_sum_xor[i] += fhe_ci_mask_s

            # Free the plaintexts of this iteration right away.
            del fhe_mult_3, fhe_mult_mask_s, fhe_s, fhe_alpha_i, fhe_ci_mask_s

        # Mask the extra row work_bit the same way (no 3x factor, no beta term).
        fhe_mult_mask_s = fhe_builder.build_plain_from_torch(
            mult_mask_s[self.work_bit])
        fhe_ci_mask_s = fhe_builder.build_plain_from_torch(
            ci_mask_s[self.work_bit])
        fhe_sum_xor[self.work_bit] *= fhe_mult_mask_s
        fhe_sum_xor[self.work_bit] += fhe_ci_mask_s

        del fhe_mult_mask_s, fhe_ci_mask_s

        if self.is_shuffle:
            with NamedTimerInstance("Shuffle"):
                # Refresh the ciphertexts via EncRefresherServer, then reorder
                # them according to shuffle_order.
                refresher = EncRefresherServer(
                    self.sum_shape, fhe_builder,
                    self.sub_name("shuffle_refresher"))
                with NamedTimerInstance("refresh"):
                    new_fhe_sum_xor = refresher.request(fhe_sum_xor)
                del fhe_sum_xor
                fhe_sum_xor = self.generate_fhe_shuffled(
                    shuffle_order, new_fhe_sum_xor)
                del refresher

        return fhe_sum_xor
Exemple #6
0
    def masking_output(self):
        """Add a fresh random mask to each output ciphertext.

        A ``[num_output_batch, degree]`` random mask comes from
        ``generate_random_mask``. For each output batch, each piece is copied
        into a slot vector until ``index_output_batch_to_units`` returns
        ``False``. The vector is reduced mod ``self.modulus``, batch-encoded
        and added to ``self.output_cts[batch]``. The sum of each copied piece
        is stored at its output unit in ``self.output_mask_s``, which is
        reduced mod ``self.modulus`` at the end.
        """
        modulus = self.modulus
        spread_mask = generate_random_mask(
            modulus, [self.num_output_batch, self.degree])
        self.output_mask_s = torch.zeros(self.num_output_unit).double()

        pod_vector = uIntVector()
        pt = Plaintext()
        for batch in range(self.num_output_batch):
            slots = torch.zeros(self.degree, dtype=torch.float)
            for piece in range(self.num_piece_in_batch):
                unit = self.index_output_batch_to_units(batch, piece)
                if unit is False:
                    break
                span = self.num_elem_in_piece
                lo = piece * span
                hi = lo + span
                piece_mask = spread_mask[batch, lo:hi]
                slots[lo:hi] = piece_mask
                self.output_mask_s[unit] = piece_mask.double().sum()
            reduced = pmod(slots, modulus)
            pod_vector.from_np(reduced.numpy().astype(np.uint64))
            self.batch_encoder.encode(pod_vector, pt)
            self.evaluator.add_plain_inplace(self.output_cts[batch], pt)

        self.output_mask_s = pmod(self.output_mask_s, modulus)
Exemple #7
0
 def check_correctness_online(output, input_s, input_c):
     """Check that ``input_s + input_c`` equals ``output`` modulo ``modulus``.

     All tensors are moved to CUDA. ``modulus`` and ``test_name`` come from
     the enclosing scope.
     """
     reference = pmod(output.cuda(), modulus)
     reconstructed = pmod(input_s.cuda() + input_c.cuda(), modulus)
     compare_expected_actual(reference,
                             reconstructed,
                             name=test_name + " online",
                             get_relative=True)
Exemple #8
0
def correctness_fc(self, input_img, output, modulus):
    """Compare ``output`` with a plaintext fully-connected layer mod ``modulus``.

    The reference flattens ``input_img`` into a single row and multiplies it
    by the transposed weight of ``self.layers[1]``. Both are double precision
    on CUDA, and the product is reduced mod ``modulus``.
    """
    row = input_img.cuda().double().view(1, -1)
    weight_t = self.layers[1].weight.cuda().double().t()
    expected = pmod(torch.mm(row, weight_t).view(-1), modulus)
    actual = pmod(output, modulus)
    # Shared with the other correctness_* helpers; after ``view(-1)`` the
    # reference is 1-D, so this squeeze never fires here.
    if expected.dim() == 4 and expected.size(0) == 1:
        expected = expected.squeeze(0)
    compare_expected_actual(expected, actual, name="fc", get_relative=True)
Exemple #9
0
 def check_correctness_online(a, b, c_s, c_c):
     """Check that ``c_s + c_c`` equals ``a * b`` modulo ``modulus``.

     ``a`` and ``b`` are multiplied in double precision on ``Config.device``.
     """
     device = Config.device
     lhs = a.double().to(device)
     rhs = b.double().to(device)
     compare_expected_actual(pmod(lhs * rhs, modulus),
                             pmod(c_s + c_c, modulus),
                             name="shares_mult_online",
                             get_relative=True)
Exemple #10
0
 def check_correctness_offline(u, v, z_s, z_c):
     """Check that ``z_s + z_c`` equals ``u * v`` modulo ``modulus``.

     ``u`` and ``v`` are multiplied in double precision on ``Config.device``.
     """
     device = Config.device
     lhs = u.double().to(device)
     rhs = v.double().to(device)
     compare_expected_actual(pmod(lhs * rhs, modulus),
                             pmod(z_s + z_c, modulus),
                             name="shares_mult_offline",
                             get_relative=True)
Exemple #11
0
def correctness_relu_only_nn(self, input_img, output, modulus):
    """Compare ``output`` with a ReLU over ``input_img`` mod ``modulus``.

    The input gets a leading batch dimension and is passed through ``nmod``
    (presumably mapping to signed representatives — confirm) before
    ``F.relu``. The result is reduced with ``pmod``. A leading singleton
    batch dimension is dropped before comparison.
    """
    batched = input_img.cuda().double().unsqueeze(0)
    expected = pmod(F.relu(nmod(batched, modulus)), modulus)
    actual = pmod(output, modulus)
    if expected.dim() == 4 and expected.size(0) == 1:
        expected = expected.squeeze(0)
    compare_expected_actual(expected, actual, name="relu_only_nn", get_relative=True)
Exemple #12
0
 def check_correctness_online(img, max_s, max_c):
     """Check the shares of a 2x2 max-pool over ``img`` modulo ``q_23``.

     Entries ``>= q_23 // 2`` are shifted down by ``q_23`` (signed form).
     The result is max-pooled per ``img_hw x img_hw`` plane on
     ``Config.device``, flattened, and compared with ``max_s + max_c``.
     """
     signed = torch.where(img < q_23 // 2, img, img - q_23).to(Config.device)
     pool = torch.nn.MaxPool2d(2)
     planes = signed.reshape([-1, img_hw, img_hw])
     pooled = pool(planes).reshape(-1)
     compare_expected_actual(pmod(pooled, q_23),
                             pmod(max_s + max_c, q_23),
                             name="maxpool2x2_dgk_online",
                             get_relative=True)
Exemple #13
0
 def check_correctness_offline(x, w, output_mask, output_c):
     """Check that ``output_c - output_mask`` equals ``conv2d(x, w)`` mod ``modulus``.

     ``x_shape``, ``w_shape``, ``padding``, ``modulus`` and ``test_name``
     come from the enclosing scope. The reference runs in double on CUDA.
     """
     unmasked = pmod(output_c.cuda() - output_mask.cuda(), modulus)
     batched_x = x.reshape([1] + x_shape).cuda().double()
     kernel = w.reshape(w_shape).cuda().double()
     conv = F.conv2d(batched_x, kernel, padding=padding)
     reference = pmod(conv.reshape(output_mask.shape), modulus)
     compare_expected_actual(reference,
                             unmasked,
                             name=test_name + " offline",
                             get_relative=True)
Exemple #14
0
def correctness_conv2d(self, input_img, output, modulus):
    """Compare ``output`` with a plaintext 2-D convolution mod ``modulus``.

    The input gets a leading batch dimension and is convolved (padding 1)
    with the weight of ``self.layers[1]`` in double on CUDA. The result is
    reduced with ``pmod``. A leading singleton batch dimension is dropped
    before comparison.
    """
    batched = input_img.cuda().double().unsqueeze(0)
    kernel = self.layers[1].weight.cuda().double()
    expected = pmod(F.conv2d(batched, kernel, padding=1), modulus)
    actual = pmod(output, modulus)
    if expected.dim() == 4 and expected.size(0) == 1:
        expected = expected.squeeze(0)
    compare_expected_actual(expected, actual, name="conv2d", get_relative=True)
Exemple #15
0
 def check_correctness_online(x, w, b, output_s, output_c):
     """Check that ``output_s + output_c`` equals ``conv2d(x, w) + b`` mod ``modulus``.

     ``b`` may be ``None`` (no bias). ``x_shape``, ``w_shape``, ``padding``,
     ``modulus`` and ``test_name`` come from the enclosing scope.
     """
     reconstructed = pmod(output_s.cuda() + output_c.cuda(), modulus)
     batched_x = x.reshape([1] + x_shape).cuda().double()
     kernel = w.reshape(w_shape).cuda().double()
     torch_b = None if b is None else b.cuda().double()
     conv = F.conv2d(batched_x, kernel, padding=padding, bias=torch_b)
     reference = pmod(conv.reshape(output_s.shape), modulus)
     compare_expected_actual(reference,
                             reconstructed,
                             name=test_name + " online",
                             get_relative=True)
Exemple #16
0
def test_conv2d_fhe_ntt_single_thread():
    """End-to-end check of ``Conv2dFheNttSingleThread`` against ``F.conv2d``.

    A random input is encoded into FHE batches and convolved with a random
    weight. The output is masked and decoded, ``prot.output_mask_s`` is
    subtracted, and the result is compared with a double-precision
    ``F.conv2d`` reference modulo ``modulus``.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    num_input_channel, num_output_channel = 64, 128
    data_range = 2**17
    half_range = data_range // 2

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Signed weights; input values drawn with bounds 0 and modulus.
    w = gen_unirand_int_grain(-half_range + 1, half_range,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw,
                                    num_input_channel, num_output_channel,
                                    fhe_builder, "test_conv2d_fhe_ntt",
                                    padding)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        decoded = prot.decode_output_from_fhe_batch()
    actual = pmod(decoded - prot.output_mask_s, modulus)

    batched_x = x.reshape([1] + x_shape).double()
    kernel = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(batched_x, kernel, padding=padding)
        reference = pmod(reference.reshape(prot.output_shape), modulus)
    compare_expected_actual(reference,
                            actual,
                            name="test_conv2d_fhe_ntt_single_thread",
                            get_relative=True)
Exemple #17
0
 def check_correctness_online(img, max_s, max_c):
     """Check shares of a 2x2 sum-pool (average pool times 4) mod ``q_23``.

     Entries ``>= q_23 // 2`` are shifted down by ``q_23``. The values are
     then average-pooled in double per ``img_hw x img_hw`` plane on CUDA,
     multiplied by 4, and compared with ``max_s + max_c``.
     """
     signed = torch.where(img < q_23 // 2, img, img - q_23).cuda()
     pool = torch.nn.AvgPool2d(2)
     planes = signed.double().reshape([-1, img_hw, img_hw])
     window_sums = pool(planes).reshape(-1) * 4
     compare_expected_actual(pmod(window_sums, q_23),
                             pmod(max_s + max_c, q_23),
                             name=test_name + "_online",
                             get_relative=True)
Exemple #18
0
 def check_correctness_online(x, w, output_s, output_c):
     """Check that ``output_s + output_c`` equals ``x @ w.T (+ bias)`` mod ``modulus``.

     ``bias``, ``x_shape``, ``w_shape``, ``modulus`` and ``test_name`` come
     from the enclosing scope. The reference runs in double on CUDA.
     """
     reconstructed = pmod(output_s.cuda() + output_c.cuda(), modulus)
     rows = x.reshape([1] + x_shape).cuda().double()
     weight = w.reshape(w_shape).cuda().double()
     reference = torch.mm(rows, weight.t())
     if bias is not None:
         reference += bias.cuda().double()
     reference = pmod(reference.reshape(output_s.shape), modulus)
     compare_expected_actual(reference,
                             reconstructed,
                             name=test_name + " online",
                             get_relative=True)
Exemple #19
0
def correctness_maxpool2x2(self, input_img, output, modulus):
    """Compare ``output`` with a 2x2 max-pool over ``input_img`` mod ``modulus``.

    The input gets a leading batch dimension and is passed through ``nmod``
    (presumably signed representatives — confirm) before pooling. The
    result is reduced with ``pmod``. A leading singleton batch dimension is
    dropped before comparison.
    """
    pool = torch.nn.MaxPool2d(2)
    batched = input_img.cuda().double().unsqueeze(0)
    expected = pmod(pool(nmod(batched, modulus)), modulus)
    actual = pmod(output, modulus)
    if expected.dim() == 4 and expected.size(0) == 1:
        expected = expected.squeeze(0)
    compare_expected_actual(expected, actual, name="maxpool2x2", get_relative=True)
Exemple #20
0
def test_fc_fhe_single_thread():
    """End-to-end check of ``FcFheSingleThread`` against ``torch.mm``.

    A random input is encoded into FHE batches and multiplied by a random
    signed weight. The output is masked and decoded, ``prot.output_mask_s``
    is subtracted, and the result is compared with a double-precision
    matrix product modulo ``Config.q_23``.
    """
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Signed weights; input values drawn with bounds 0 and modulus.
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit,
                             num_output_unit, fhe_builder, test_name)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    # Timer labels name the FC stages (they previously read "conv2d").
    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("fc with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("fc masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # Subtract the recorded output mask. The former unmasked
    # ``pmod(y, modulus)`` was overwritten immediately and has been removed.
    actual = pmod(y - prot.output_mask_s, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Fc Torch"):
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected,
                            actual,
                            name=test_name,
                            get_relative=True)
    print(f"\nTest for {test_name}: End")
Exemple #21
0
 def mod_div_online(self, z):
     """Send this party's share of the pre-correction indicator.

     The indicator is 1 where ``z < self.nullify_threshold`` and 0
     elsewhere. The value sent over ``self.common.pre_corr_mod_s`` is that
     indicator minus ``self.pre_mod_div_c``, reduced mod ``self.q_23``.
     """
     below = z < self.nullify_threshold
     zeros = self.elem_zeros
     indicator = torch.where(below, zeros + 1, zeros)
     share = pmod(indicator - self.pre_mod_div_c, self.q_23)
     self.common.pre_corr_mod_s.send(share)
Exemple #22
0
 def compute_with_weight(self, weight_tensor):
     """Multiply the encrypted input batches by ``weight_tensor`` into output batches.

     ``self.output_cts`` is reset to ``self.num_output_batch`` encrypted
     zeros. Each (output batch, input batch) pair runs through the
     following steps:

     1. Pack the weight row segments from ``index_weight_batch_to_units``
        into a slot vector.
     2. Reduce mod ``self.modulus``, batch-encode, and multiply into a copy
        of ``self.input_cts[input batch]``.
     3. Add that product to ``self.output_cts[output batch]``.

     Pairs that receive no weight segment are skipped.
     """
     assert (weight_tensor.shape == self.weight_shape)
     pod_vector = uIntVector()
     pt = Plaintext()
     self.output_cts = encrypt_zeros(self.num_output_batch,
                                     self.batch_encoder, self.encryptor,
                                     self.degree)
     for out_batch in range(self.num_output_batch):
         for in_batch in range(self.num_input_batch):
             slots = torch.zeros(self.degree)
             touched = False
             for piece in range(self.num_piece_in_batch):
                 row, col_lo, col_hi = self.index_weight_batch_to_units(
                     out_batch, in_batch, piece)
                 if row is False:
                     continue
                 touched = True
                 # Each piece starts at a padded offset; only col_hi - col_lo
                 # slots are filled.
                 lo = piece * self.num_elem_in_piece
                 slots[lo:lo + (col_hi - col_lo)] = weight_tensor[
                     row, col_lo:col_hi]
             if not touched:
                 continue
             reduced = pmod(slots, self.modulus)
             pod_vector.from_np(reduced.numpy().astype(np.uint64))
             self.batch_encoder.encode(pod_vector, pt)
             partial = Ciphertext(self.input_cts[in_batch])
             self.evaluator.multiply_plain_inplace(partial, pt)
             self.evaluator.add_inplace(self.output_cts[out_batch], partial)
Exemple #23
0
def shift_by_exp(data, exp, mode="stochastic"):
    """Floor-divide ReLU'd ``data`` by ``2**-exp`` through a random mask, mod ``modulus``.

    Steps:

    1. The input is mapped with ``nmod`` and passed through ``F.relu``.
    2. A random ``r`` from the "plain" generator of
       ``MetaTruncRandomGenerator`` is added mod ``p = modulus``.
    3. The result is ``(x + r) // d - r // d``. When ``x + r`` wrapped
       around ``p`` (i.e. it is smaller than ``r``), ``p // d`` is added.
       Either branch is mapped with ``nmod``.

    ``mode`` is accepted but not used.
    """
    divisor = 2**-exp
    p = modulus

    generator = MetaTruncRandomGenerator().get_rg("plain")
    mask = generator.gen_uniform(data.numel(), p).cuda().reshape_as(data)

    positive = F.relu(nmod(data, p))
    masked = pmod(positive + mask, p)
    shifted = masked // divisor - mask // divisor
    wrapped = nmod(shifted + p // divisor, p)
    unwrapped = nmod(shifted, p)
    return torch.where(masked < mask, wrapped, unwrapped)
Exemple #24
0
 def decomp_to_bit(self, x, res=None):
     """Write the ``self.work_bit`` lowest bits of ``x`` into rows, LSB first.

     Row ``i`` is ``pmod(x // 2**i, 2)``, computed on a copy of ``x`` moved
     to ``Config.device``, so ``x`` itself is not modified. When ``res`` is
     ``None``, a ``[self.work_bit, self.num_elem]`` zeros tensor is
     allocated (no explicit device). Otherwise ``res`` is filled in place.
     The filled tensor is returned.
     """
     remaining = torch.clone(x).to(Config.device)
     if res is None:
         res = torch.zeros([self.work_bit, self.num_elem])
     for bit in range(self.work_bit):
         res[bit] = pmod(remaining, 2)
         remaining //= 2
     return res
Exemple #25
0
 def online(self, input_s):
     """Compute this party's fully-connected output share from ``input_s``.

     Steps:

     1. Reshape to ``[1] + self.input_shape`` as double on CUDA.
     2. Multiply by ``self.weight.t()`` and add ``self.bias`` when it is set.
     3. Reshape to ``self.output_shape``, subtract ``self.output_mask_s``
        and reduce mod ``self.modulus``.

     The result is stored in ``self.output_s``.
     """
     row = input_s.reshape([1, *self.input_shape]).cuda().double()
     y_s = torch.mm(row, self.weight.t())
     if self.bias is not None:
         y_s += self.bias
     self.output_s = pmod(y_s.reshape(self.output_shape) - self.output_mask_s,
                          self.modulus)
Exemple #26
0
def test_ntt_conv():
    """Compare ``Conv2dNtt`` against ``F.conv2d`` on random signed data.

    ``conv2d`` is called once, untimed, and its result is discarded. Then
    the NTT of x, the NTT of w and ``conv2d_loaded`` are timed separately.
    That result is compared with a double-precision ``F.conv2d`` reference
    modulo ``modulus``.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    num_input_channel, num_output_channel = 64, 128
    data_range = 2**17
    low, high = -data_range // 2 + 1, data_range // 2

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    x = gen_unirand_int_grain(low, high, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(low, high, get_prod(w_shape)).reshape(w_shape)

    conv2d_ntt = Conv2dNtt(modulus, data_range, img_hw, filter_hw,
                           num_input_channel, num_output_channel,
                           "test_conv2d_ntt", padding)
    conv2d_ntt.conv2d(x, w)

    with NamedTimerInstance("ntt x"):
        conv2d_ntt.load_and_ntt_x(x)
    with NamedTimerInstance("ntt w"):
        conv2d_ntt.load_and_ntt_w(w)
    with NamedTimerInstance("conv2d"):
        loaded = conv2d_ntt.conv2d_loaded()
    actual = pmod(loaded, modulus)

    batched_x = x.reshape([1] + x_shape).double()
    kernel = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(batched_x, kernel, padding=padding)
        reference = pmod(reference.reshape(conv2d_ntt.y_shape), modulus)
    compare_expected_actual(reference,
                            actual,
                            name="ntt",
                            get_relative=True,
                            show_where_err=False)
Exemple #27
0
    def check_correctness(input_img, output):
        """Recompute the plaintext network mod ``q_23`` and compare with ``output``.

        The stages are conv1 (padding 1) -> ReLU on ``nmod`` values ->
        2x2 max-pool on ``nmod`` values -> floor-divide by ``2**pow_to_div``
        -> conv2 (padding 1) -> flatten -> fc1. Each stage is reduced with
        ``pmod``. ``conv1_w``, ``conv2_w``, ``fc1_w``, ``pow_to_div``,
        ``q_23`` and ``test_name`` come from the enclosing scope.
        """
        device = Config.device
        pool = torch.nn.MaxPool2d(2)

        def reduce(t):
            return pmod(t, q_23)

        act = input_img.to(device).double()
        act = act.reshape([1] + list(act.shape))
        act = reduce(F.conv2d(act, conv1_w.to(device).double(), padding=1))
        act = reduce(F.relu(nmod(act, q_23)))
        act = reduce(pool(nmod(act, q_23)))
        act = reduce(act // (2**pow_to_div))
        act = reduce(F.conv2d(act, conv2_w.to(device).double(), padding=1))
        flat = act.view(-1).view(1, -1)
        act = reduce(torch.mm(flat, fc1_w.to(device).double().t()).view(-1))

        reference = act
        if len(reference.shape) == 4 and reference.shape[0] == 1:
            reference = reference.reshape(reference.shape[1:])
        compare_expected_actual(reference,
                                pmod(output, q_23),
                                name=test_name,
                                get_relative=True)
Exemple #28
0
 def online(self, input_s):
     """Compute this party's convolution output share from ``input_s``.

     Steps:

     1. Reshape to ``[1] + self.input_shape`` as double on CUDA.
     2. Convolve with ``self.torch_w`` (``self.padding``, ``self.bias``).
     3. Reshape to ``self.y_shape``, subtract ``self.output_mask_s`` and
        reduce mod ``self.modulus``.

     The result is stored in ``self.output_s``.
     """
     batched = input_s.reshape([1, *self.input_shape]).cuda().double()
     conv = F.conv2d(batched, self.torch_w, padding=self.padding,
                     bias=self.bias)
     self.output_s = pmod(conv.reshape(self.y_shape) - self.output_mask_s,
                          self.modulus)
Exemple #29
0
 def check_correctness_mod_div(r, z, correct_mod_div_work_s,
                               correct_mod_div_work_c):
     """Check the mod-div correction shares against ``r > z``.

     Expected value per element: ``q_23 // work_range`` where ``r > z``,
     otherwise 0. The two shares are summed mod ``q_23``. ``num_elem``,
     ``q_23`` and ``work_range`` come from the enclosing scope.
     """
     zeros = torch.zeros(num_elem).to(Config.device)
     correction = q_23 // work_range + zeros
     reference = torch.where(r > z, correction, zeros)
     combined = pmod(correct_mod_div_work_s + correct_mod_div_work_c, q_23)
     compare_expected_actual(reference,
                             combined,
                             get_relative=True,
                             name="mod_div_online")
Exemple #30
0
 def conv2d_loaded(self):
     """Convolve the NTT'd inputs and weights already stored on ``self``.

     For each (output channel, input channel) pair, the NTT-domain product
     of ``self.ntted_x[in]`` and ``self.ntted_w[out, in]`` goes through
     ``transform_y_single_channel`` and is accumulated into the output
     plane. The double-precision sum is reduced mod ``self.modulus``.
     """
     hw = self.output_hw
     y = torch.zeros([self.num_output_channel, hw, hw]).double()
     for out_c in range(self.num_output_channel):
         for in_c in range(self.num_input_channel):
             dotted = self.conv2d_ntted_single_channel(
                 self.ntted_x[in_c], self.ntted_w[out_c, in_c])
             y[out_c, :, :] += self.transform_y_single_channel(dotted)
     return pmod(y, self.modulus)