def test_polynomial_multiplication(thread, transform_name): batch_shape = (10, ) if transform_name == 'FFT': transform = fft512() transform_ref = tr_fft.fft_transform_ref tr_mul_ref = tr_fft.fft_transformed_mul_ref else: transform = ntt1024() transform_ref = tr_ntt.ntt_transform_ref tr_mul_ref = tr_ntt.ntt_transformed_mul_ref tr_forward = Transform( transform, batch_shape, inverse=False, i32_conversion=True, transforms_per_block=1, ).compile(thread) tr_inverse = Transform( transform, batch_shape, inverse=True, i32_conversion=True, transforms_per_block=1, ).compile(thread) a = numpy.random.randint(-2**31, 2**31, size=batch_shape + (1024, ), dtype=numpy.int32) b = numpy.random.randint(-1000, 1000, size=batch_shape + (1024, ), dtype=numpy.int32) a_dev = thread.to_device(a) b_dev = thread.to_device(b) a_tr_dev = thread.empty_like(tr_forward.parameter.output) b_tr_dev = thread.empty_like(tr_forward.parameter.output) res_dev = thread.empty_like(tr_inverse.parameter.output) tr_forward(a_tr_dev, a_dev) tr_forward(b_tr_dev, b_dev) res_tr = tr_mul_ref(a_tr_dev.get(), b_tr_dev.get()) res_tr_dev = thread.to_device(res_tr) tr_inverse(res_dev, res_tr_dev) res_test = res_dev.get() res_ref = poly_mul_ref(a, b) assert numpy.allclose(res_test, res_ref)
def test_transform_correctness(thread, transform_type, inverse, i32_conversion, constant_memory): if not transform_supported(thread.device_params, transform_type): pytest.skip() batch_shape = (128,) if transform_type == 'FFT': transform = fft512(use_constant_memory=constant_memory) transform_ref = tr_fft.fft_transform_ref else: transform = ntt1024(use_constant_memory=constant_memory) transform_ref = tr_ntt.ntt_transform_ref comp = Transform( transform, batch_shape, inverse=inverse, i32_conversion=i32_conversion, transforms_per_block=1, ).compile(thread) a = get_test_array(comp.parameter.input.shape, comp.parameter.input.dtype) a_dev = thread.to_device(a) res_dev = thread.empty_like(comp.parameter.output) comp(res_dev, a_dev) res_test = res_dev.get() res_ref = transform_ref(a, inverse=inverse, i32_conversion=i32_conversion) if numpy.issubdtype(res_dev.dtype, numpy.integer): assert (res_test == res_ref).all() else: assert numpy.allclose(res_test, res_ref)
def test_transform_correctness(thread, transform_name, inverse, i32_conversion, constant_memory): batch_shape = (128, ) if transform_name == 'FFT': transform = fft512(use_constant_memory=constant_memory) transform_ref = tr_fft.fft_transform_ref else: transform = ntt1024(use_constant_memory=constant_memory) transform_ref = tr_ntt.ntt_transform_ref comp = Transform( transform, batch_shape, inverse=inverse, i32_conversion=i32_conversion, transforms_per_block=1, ).compile(thread) a = get_test_array(comp.parameter.input.shape, comp.parameter.input.dtype) a_dev = thread.to_device(a) res_dev = thread.empty_like(comp.parameter.output) comp(res_dev, a_dev) res_ref = transform_ref(a, inverse=inverse, i32_conversion=i32_conversion) assert numpy.allclose(res_dev.get(), res_ref)
def test_fft_performance(thread, transforms_per_block, constant_memory, heavy_performance_load): if not transform_supported(thread.device_params, 'FFT'): pytest.skip() if transforms_per_block > max_supported_transforms_per_block(thread.device_params, 'FFT'): pytest.skip() is_cuda = thread.api.get_id() == cuda_id() batch_shape = (2**14,) a = get_test_array(batch_shape + (512,), numpy.complex128) kernel_repetitions = 100 if heavy_performance_load else 5 a_dev = thread.to_device(a) res_dev = thread.empty_like(a_dev) res_ref = tr_fft.fft_transform_ref(a) transform = fft512(use_constant_memory=constant_memory) fft_comp = Transform( transform, batch_shape, transforms_per_block=transforms_per_block, ).compile(thread) fft_comp_repeated = Transform( transform, batch_shape, transforms_per_block=transforms_per_block, kernel_repetitions=kernel_repetitions).compile(thread) # Quick check of correctness fft_comp(res_dev, a_dev) res_test = res_dev.get() assert numpy.allclose(res_test, res_ref) # Test performance times, times_str = get_times(thread, fft_comp_repeated, res_dev, a_dev) print("\n{backend}, {trnum} per block, test --- {times}".format( times=times_str, backend='cuda' if is_cuda else 'ocl ', trnum=transforms_per_block))
def test_ntt_performance(thread, transforms_per_block, constant_memory, heavy_performance_load): if not transform_supported(thread.device_params, 'NTT'): pytest.skip() if transforms_per_block > max_supported_transforms_per_block(thread.device_params, 'NTT'): pytest.skip() is_cuda = thread.api.get_id() == cuda_id() methods = list(itertools.product( ['cuda_asm', 'c'], # base method ['cuda_asm', 'c_from_asm', 'c'], # mul method ['cuda_asm', 'c_from_asm', 'c'] # lsh method )) if not is_cuda: # filter out all usage of CUDA asm if we're on OpenCL methods = [ms for ms in methods if 'cuda_asm' not in ms] batch_shape = (2**14,) a = get_test_array(batch_shape + (1024,), "ff_number") kernel_repetitions = 100 if heavy_performance_load else 5 a_dev = thread.to_device(a) res_dev = thread.empty_like(a_dev) # TODO: compute a reference NTT when it's fast enough on CPU #res_ref = tr_ntt.ntt_transform_ref(a) print() min_times = [] for base_method, mul_method, lsh_method in methods: transform = ntt1024( base_method=base_method, mul_method=mul_method, lsh_method=lsh_method, use_constant_memory=constant_memory) ntt_comp = Transform( transform, batch_shape, transforms_per_block=transforms_per_block, ).compile(thread) ntt_comp_repeated = Transform( transform, batch_shape, transforms_per_block=transforms_per_block, kernel_repetitions=kernel_repetitions).compile(thread) # TODO: compute a reference NTT when it's fast enough on CPU # Quick check of correctness #ntt_comp(res_dev, a_dev) #res_test = res_dev.get() #assert (res_test == res_ref).all() # Test performance times, times_str = get_times(thread, ntt_comp_repeated, res_dev, a_dev) print(" base: {bm}, mul: {mm}, lsh: {lm}".format( bm=base_method, mm=mul_method, lm=lsh_method)) print(" {backend}, {trnum} per block, test --- {times}".format( times=times_str, backend='cuda' if is_cuda else 'ocl ', trnum=transforms_per_block)) min_times.append((times.min(), base_method, mul_method, lsh_method)) best = min(min_times, key=lambda t: t[0]) time_best, base_method, mul_method, lsh_method = best print("Best time: {tb:.4f} for [base: {bm}, mul: {mm}, lsh: {lm}]".format( tb=time_best, bm=base_method, mm=mul_method, lm=lsh_method ))