def combine_complex():
    """
    Build a transformation that merges two real inputs into one complex output.

    The transformation has 2 inputs and 1 output and computes
    ``output = input1 + 1j * input2``.
    The output dtype is the complex counterpart of the first input's dtype.
    """
    def derive_output_dtype(i1, i2):
        # Only the first input's dtype determines the complex output type.
        return dtypes.complex_for(i1)

    # Mako template: build a complex value from the two loaded reals and store it.
    kernel_code = "${o1.store}(COMPLEX_CTR(${o1.ctype})(${i1.load}, ${i2.load}));"

    return Transformation(
        inputs=2,
        outputs=1,
        derive_o_from_is=derive_output_dtype,
        code=kernel_code,
    )
def split_complex():
    """
    Build a transformation that splits one complex input into two real outputs.

    The transformation has 1 input and 2 outputs and computes
    ``output1 = Re(input1), output2 = Im(input1)``.
    The input dtype is derived as the complex counterpart of the first output's dtype.
    """
    def derive_input_dtype(o1, o2):
        # The complex input type follows from the first real output's dtype.
        return dtypes.complex_for(o1)

    # Mako template: store the .x component into output 1 and the .y component
    # into output 2 (the real and imaginary parts, per the docstring above).
    kernel_code = """ ${o1.store}(${i1.load}.x); ${o2.store}(${i1.load}.y); """

    return Transformation(
        inputs=1,
        outputs=2,
        derive_i_from_os=derive_input_dtype,
        code=kernel_code,
    )
def test_errors(ctx_and_double, shapes, arg_dtypes):
    """
    Run ``MatrixMul`` on the device and compare it with the ``ref_dot`` reference.

    ``shapes`` supplies the two operand shapes. ``arg_dtypes`` supplies one
    flag per operand: a truthy flag makes that operand complex, otherwise it
    is real. The real base precision is float64 when the ``double`` part of the
    ``ctx_and_double`` fixture is truthy, float32 otherwise. The test asserts
    that ``diff_is_negligible`` accepts the device result against the reference.
    """
    ctx, double = ctx_and_double
    shape_a, shape_b = shapes
    complex_a, complex_b = arg_dtypes

    base_dtype = numpy.float64 if double else numpy.float32

    def pick_dtype(is_complex):
        # Promote to the matching complex dtype only for complex-flagged operands.
        return dtypes.complex_for(base_dtype) if is_complex else base_dtype

    dtype_a = pick_dtype(complex_a)
    dtype_b = pick_dtype(complex_b)

    # Host-side operands and the reference product.
    host_a = get_test_array(shape_a, dtype_a)
    host_b = get_test_array(shape_b, dtype_b)
    expected = ref_dot(host_a, host_b)

    # Device-side buffers; the output is allocated to match the reference result.
    dev_a = ctx.to_device(host_a)
    dev_b = ctx.to_device(host_b)
    dev_out = ctx.empty_like(expected)

    matmul = MatrixMul(ctx).prepare_for(dev_out, dev_a, dev_b)
    matmul(dev_out, dev_a, dev_b)

    assert diff_is_negligible(ctx.from_device(dev_out), expected)