def test_perf_ckks_conv_eval(benchmark, image_shape):
    """Benchmark a 4-output-channel convolution over an im2col-encoded CKKS image.

    A random image of shape ``image_shape`` is encoded once with
    ``ts.im2col_encoding`` (7x7 windows, stride 3). The timed body evaluates one
    ``conv2d_im2col`` per output channel of a ``torch.nn.Conv2d(1, 4, 7, stride=3)``,
    adds that channel's bias, and packs the per-channel ciphertexts into one
    vector with ``ts.CKKSVector.pack_vectors``.
    """
    ctx = ckks_context()
    img = np.random.randn(*image_shape)
    kernel_shape = [7, 7]
    stride = 3

    # Build the layer from the same kernel/stride values used for the encoding
    # so the two cannot drift apart.
    conv1 = torch.nn.Conv2d(1, 4, kernel_size=tuple(kernel_shape), padding=0, stride=stride)
    out_ch = conv1.out_channels
    kh, kw = conv1.kernel_size
    # Single input channel, so each (1, kh, kw) filter collapses to a kh x kw kernel.
    conv1_weight = conv1.weight.data.view(out_ch, kh, kw).tolist()
    conv1_bias = conv1.bias.data.tolist()

    # img already has shape image_shape, so it can be converted to nested lists directly.
    x_enc, windows_nb = ts.im2col_encoding(ctx, img.tolist(), kernel_shape[0], kernel_shape[1], stride)

    def op():
        # One encrypted im2col convolution (plus bias) per output channel.
        channels = [
            x_enc.conv2d_im2col(kernel, windows_nb) + bias
            for kernel, bias in zip(conv1_weight, conv1_bias)
        ]
        ts.CKKSVector.pack_vectors(channels)

    benchmark.pedantic(op, rounds=rounds, iterations=iterations)
def test_conv2d_im2col_inplace(context, input_size, kernel_size, stride):
    """Check the in-place ``conv2d_im2col_`` against a plaintext im2col matmul.

    A random square image and kernel are generated. The image is im2col-encoded
    with TenSEAL and convolved in place. The decrypted result is compared with
    the product of the plaintext window matrix and the flattened kernel, both
    zero-padded to the next power of two.
    """

    def generate_input(input_size, kernel_size, stride):
        """Return (image, padded window matrix, kernel, padded flat kernel)."""
        image = np.random.randn(input_size, input_size)
        filt = np.random.randn(kernel_size, kernel_size)
        k_rows, k_cols = filt.shape
        out_rows = (image.shape[0] - k_rows) // stride + 1
        out_cols = (image.shape[1] - k_cols) // stride + 1

        # One row per sliding window, each row the flattened window contents.
        windows = view_as_windows(image, filt.shape, step=stride)
        windows = windows.reshape(out_rows * out_cols, k_rows * k_cols)

        # Zero-pad windows and kernel to a power-of-two length; zeros do not change the dot products.
        padded_len = 2 ** math.ceil(math.log2(filt.size))
        extra = padded_len - filt.size
        windows = np.pad(windows, ((0, 0), (0, extra)))
        flat_filt = np.pad(filt.flatten(), (0, extra))
        return image, windows, filt, flat_filt

    # Galois keys are needed for rotations on ciphertext vectors.
    context.generate_galois_keys()

    image, windows, filt, flat_filt = generate_input(input_size, kernel_size, stride)
    x_enc, windows_nb = ts.im2col_encoding(context, image, filt.shape[0], filt.shape[1], stride)
    x_enc.conv2d_im2col_(filt.tolist(), windows_nb)

    got = x_enc.decrypt()
    want = (windows @ flat_filt).tolist()
    assert _almost_equal(got, want, 0)
def encrypted_model(tensor):
    """Encrypt ``tensor`` as a 28x28 image, run ``enc_model`` on it, and decrypt.

    ``tensor`` must hold 28 * 28 elements (it is reshaped with ``view(28, 28)``).
    The image is im2col-encoded with the module-level ``context``,
    ``kernel_shape`` and ``stride``. The decrypted outputs are returned as a
    ``torch`` tensor of shape ``(1, N)``.
    """
    image_rows = tensor.view(28, 28).tolist()
    enc_x, n_windows = ts.im2col_encoding(context, image_rows, kernel_shape[0], kernel_shape[1], stride)
    decrypted = enc_model.forward(enc_x, n_windows).decrypt()
    return torch.tensor(decrypted).view(1, -1)
def op():
    """Benchmark body: im2col-encode ``img`` under ``ctx``; the result is discarded.

    Reads ``ctx``, ``img``, ``image_shape``, ``kernel_shape`` and ``stride`` from
    the enclosing/module scope.
    """
    plain = img.reshape(image_shape).tolist()
    # Only the encoding cost is measured; the returned pair is unpacked and dropped.
    _enc, _windows = ts.im2col_encoding(ctx, plain, kernel_shape[0], kernel_shape[1], stride)
def prepare_input(ctx, plain_input):
    """im2col-encode ``plain_input`` for a 7x7 kernel with stride 3 and return it.

    Asserts that the encoding produced exactly 64 windows; the window count
    itself is not returned.
    """
    encoded, n_windows = ts.im2col_encoding(ctx, plain_input, 7, 7, 3)
    # 64 == 8 * 8, which is what a 28x28 input gives with a 7x7 kernel and
    # stride 3 ((28 - 7) // 3 + 1 == 8) -- presumably the expected input size.
    assert n_windows == 64
    return encoded