def test_lwe_linear_broadcast(thread):
    """
    Compare ``LweLinear`` with ``LweLinearReference`` when the source samples
    must be broadcast over the leading axis of the result samples.

    The source arrays have shape ``res_shape[1:]``, so they carry one fewer
    leading dimension than the result arrays. The kernel and the reference are
    both run with ``add_result=True`` and a coefficient of 1, and their result
    arrays are then compared.
    """
    lwe_size = NuFHEParameters().in_out_params.size

    res_shape = (10, 20)
    src_shape = res_shape[1:]  # drop the leading axis: src is broadcast over it

    # Host-side inputs as (a, b, current_variances) triples.
    # Generation order is fixed: result arrays first, then source arrays.
    res_host = (
        get_test_array(res_shape + (lwe_size,), Torus32),
        get_test_array(res_shape, Torus32),
        get_test_array(res_shape, ErrorFloat, (-1, 1)),
    )
    src_host = (
        get_test_array(src_shape + (lwe_size,), Torus32),
        get_test_array(src_shape, Torus32),
        get_test_array(src_shape, ErrorFloat, (-1, 1)),
    )

    coeff = 1
    add_result = True

    res_info = LweSampleArrayShapeInfo(*res_host)
    src_info = LweSampleArrayShapeInfo(*src_host)

    kernel = LweLinear(res_info, src_info, add_result=add_result).compile(thread)
    reference = LweLinearReference(res_info, src_info, add_result=add_result)

    # Upload before the reference runs: the host result arrays are compared
    # against the device output after the reference call, so the reference is
    # expected to update them in place.
    res_dev = [thread.to_device(arr) for arr in res_host]
    src_dev = [thread.to_device(arr) for arr in src_host]
    thread.synchronize()

    kernel(*res_dev, *src_dev, coeff)
    reference(*res_host, *src_host, coeff)

    # The integer parts (a, b) must match exactly.
    for dev_arr, host_arr in zip(res_dev[:2], res_host[:2]):
        assert (dev_arr.get() == host_arr).all()
    # Variances are floating-point and are compared approximately.
    assert errors_allclose(res_dev[2].get(), res_host[2])
def test_lwe_linear(thread, positive_coeff, add_result):
    """
    Compare ``LweLinear`` with ``LweLinearReference`` for result and source
    sample arrays of identical shape.

    ``positive_coeff`` selects a coefficient of +1 (truthy) or -1 (falsy).
    ``add_result`` is forwarded unchanged to both the kernel and the reference.
    """
    params = NuFHEParameters()
    lwe_size = params.in_out_params.size
    shape = (10, 20)

    # Variance arrays use ErrorFloat and are checked with errors_allclose,
    # consistent with test_lwe_linear_broadcast, so that the kernel is tested
    # with the same variance dtype as the rest of this file.
    res_a = get_test_array(shape + (lwe_size, ), Torus32)
    res_b = get_test_array(shape, Torus32)
    res_cv = get_test_array(shape, ErrorFloat, (-1, 1))

    src_a = get_test_array(shape + (lwe_size, ), Torus32)
    src_b = get_test_array(shape, Torus32)
    src_cv = get_test_array(shape, ErrorFloat, (-1, 1))

    coeff = 1 if positive_coeff else -1

    # Result and source have the same shape, so one descriptor serves both.
    shape_info = LweSampleArrayShapeInfo(src_a, src_b, src_cv)

    test = LweLinear(shape_info, shape_info, add_result=add_result).compile(thread)
    ref = LweLinearReference(shape_info, shape_info, add_result=add_result)

    # Upload before the reference runs: the host result arrays are compared
    # against the device output afterwards, so the reference is expected to
    # update them in place.
    res_a_dev = thread.to_device(res_a)
    res_b_dev = thread.to_device(res_b)
    res_cv_dev = thread.to_device(res_cv)
    src_a_dev = thread.to_device(src_a)
    src_b_dev = thread.to_device(src_b)
    src_cv_dev = thread.to_device(src_cv)
    thread.synchronize()

    test(res_a_dev, res_b_dev, res_cv_dev, src_a_dev, src_b_dev, src_cv_dev, coeff)
    ref(res_a, res_b, res_cv, src_a, src_b, src_cv, coeff)

    # Integer parts must match exactly; variances are compared approximately.
    assert (res_a_dev.get() == res_a).all()
    assert (res_b_dev.get() == res_b).all()
    assert errors_allclose(res_cv_dev.get(), res_cv)