def input_transform(ymm_d):
    """Emit AVX2/FMA code for a Winograd-style input transform over 8 rows.

    Args:
        ymm_d: list of 8 YMMRegister holding input rows d0..d7.

    Returns:
        list of 8 freshly allocated YMMRegister with the transformed rows
        wd0..wd7 (formulas in the comment block below).

    NOTE(review): ymm_d[1] and ymm_d[6] are written (FMA accumulators) and
    then exchanged into the output via SWAP.REGISTERS, so the caller must
    treat them as clobbered. Statement order is significant: every DSL call
    appends one instruction to the generated code.
    """
    assert isinstance(ymm_d, list) and len(ymm_d) == 8 and all(isinstance(ymm, YMMRegister) for ymm in ymm_d)
    ymm_wd = [YMMRegister() for _ in range(8)]

    # wd0 = (d0 - d6) + 5.25 * (d4 - d2)
    # wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3)
    # wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3)
    # wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
    # wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
    # wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
    # wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
    # wd7 = (d7 - d1) + 5.25 * (d3 - d5)

    ymm_0_25 = YMMRegister()
    VMOVAPS(ymm_0_25, Constant.float32x8(0.25))

    # Compute wd0 := d0 - d6
    VSUBPS(ymm_wd[0], ymm_d[0], ymm_d[6])
    ymm_d4_sub_d2 = YMMRegister()
    VSUBPS(ymm_d4_sub_d2, ymm_d[4], ymm_d[2])
    # Compute wd7 := d7 - d1
    VSUBPS(ymm_wd[7], ymm_d[7], ymm_d[1])
    ymm_d3_sub_d5 = YMMRegister()
    VSUBPS(ymm_d3_sub_d5, ymm_d[3], ymm_d[5])
    # Compute wd1 := d2 + d6
    VADDPS(ymm_wd[1], ymm_d[2], ymm_d[6])
    # Compute wd2 := d1 + d5
    VADDPS(ymm_wd[2], ymm_d[1], ymm_d[5])
    # Compute wd4 := d5 + 0.25 * d1
    VMOVAPS(ymm_wd[4], ymm_d[5])
    VFMADD231PS(ymm_wd[4], ymm_d[1], ymm_0_25)
    # Compute wd5 := d6 - 5.0 * d4
    # (VFNMADD132PS: wd5 := -(wd5 * d4) + d6, with wd5 preloaded to 5.0)
    VMOVAPS(ymm_wd[5], Constant.float32x8(5.0))
    VFNMADD132PS(ymm_wd[5], ymm_d[6], ymm_d[4])
    # Compute wd3 := d6 + 0.25 * d2 (accumulated in place of d6, then swapped out)
    VFMADD231PS(ymm_d[6], ymm_d[2], ymm_0_25)
    SWAP.REGISTERS(ymm_wd[3], ymm_d[6])
    # Compute wd6 := d1 + 0.25 * d5 (accumulated in place of d1, then swapped out)
    VFMADD231PS(ymm_d[1], ymm_d[5], ymm_0_25)
    SWAP.REGISTERS(ymm_wd[6], ymm_d[1])

    ymm_5_25 = YMMRegister()
    VMOVAPS(ymm_5_25, Constant.float32x8(5.25))
    # Compute wd0 := (d0 - d6) + 5.25 * (d4 - d2)
    VFMADD231PS(ymm_wd[0], ymm_d4_sub_d2, ymm_5_25)
    # Compute wd7 := (d7 - d1) + 5.25 * (d3 - d5)
    VFMADD231PS(ymm_wd[7], ymm_d3_sub_d5, ymm_5_25)

    ymm_4_25 = YMMRegister()
    VMOVAPS(ymm_4_25, Constant.float32x8(4.25))
    # Compute
    #   wd1 := (d6 + d2) - 4.25 * d4
    #   wd2 := (d1 + d5) - 4.25 * d3
    VFNMADD231PS(ymm_wd[1], ymm_d[4], ymm_4_25)
    VFNMADD231PS(ymm_wd[2], ymm_d[3], ymm_4_25)

    ymm_1_25 = YMMRegister()
    VMOVAPS(ymm_1_25, Constant.float32x8(1.25))
    # Compute
    #   wd3 := (d6 + 0.25 * d2) - 1.25 * d4
    #   wd4 := (d5 + 0.25 * d1) - 1.25 * d3
    #   wd6 := (d1 + 0.25 * d5) - 1.25 * d3
    #   wd5 := (d6 - 5.0 * d4) + 4.0 * d2
    VFNMADD231PS(ymm_wd[3], ymm_d[4], ymm_1_25)
    VFNMADD231PS(ymm_wd[4], ymm_d[3], ymm_1_25)
    VFMADD231PS(ymm_wd[5], ymm_d[2], Constant.float32x8(4.0))
    VFNMADD231PS(ymm_wd[6], ymm_d[3], ymm_1_25)

    # Final add/sub pairs; the odd member of pairs (3,4) and (5,6) carries
    # the 2.0 factor from the formulas above via scale_b.
    ymm_2 = YMMRegister()
    VMOVAPS(ymm_2, Constant.float32x8(2.0))
    butterfly(ymm_wd[1], ymm_wd[2])
    butterfly(ymm_wd[3], ymm_wd[4], scale_b=ymm_2)
    butterfly(ymm_wd[5], ymm_wd[6], scale_b=ymm_2)

    return ymm_wd
def ifft8_across_rows(ymm_data, bias=None):
    """Emit an inverse 8-point real-to-complex FFT across 8 rows, in place.

    Args:
        ymm_data: list of 8 YMMRegister rows, interleaved as
            (real, imag) pairs; transformed in place.
        bias: optional YMMRegister or memory operand. When given, the 1/8
            output scaling is deferred and fused with the bias addition
            after the transform; when None, scaling is applied to the
            inputs up front.

    NOTE(review): rows ymm_real[2]/ymm_imag[2] receive a +/-2x scale
    relative to the others (0.25/-0.25 vs 0.125, or 2.0/-2.0 in the
    deferred path) — presumably a packing convention of this row pair;
    confirm against the matching forward transform.
    """
    assert isinstance(ymm_data, list) and len(ymm_data) == 8
    ymm_real = ymm_data[0::2]
    ymm_imag = ymm_data[1::2]

    if bias is None:
        # Do 1/N scaling before IFFT
        ymm_one_eighth = YMMRegister()
        VMOVAPS(ymm_one_eighth, Constant.float32x8(0.125))
        for ymm_row in ymm_data:
            if ymm_row is ymm_real[2]:
                VMULPS(ymm_row, ymm_row, Constant.float32x8(0.25))
            elif ymm_row is ymm_imag[2]:
                VMULPS(ymm_row, ymm_row, Constant.float32x8(-0.25))
            else:
                VMULPS(ymm_row, ymm_row, ymm_one_eighth)
    else:
        # Do 1/N scaling after FFT (merge with bias addition); only the
        # +/-2x adjustment for this row pair is applied here.
        VMULPS(ymm_real[2], ymm_real[2], Constant.float32x8(2.0))
        VMULPS(ymm_imag[2], ymm_imag[2], Constant.float32x8(-2.0))

    butterfly(ymm_real[0], ymm_imag[0])

    # H1.real, H1.imag = W1.real - W3.real, W1.imag + W3.imag
    ymm_h1_real, ymm_h1_imag = YMMRegister(), YMMRegister()
    VSUBPS(ymm_h1_real, ymm_real[1], ymm_real[3])
    VADDPS(ymm_h1_imag, ymm_imag[1], ymm_imag[3])

    # G1.real, G1.imag = W1.real + W3.real, W1.imag - W3.imag
    ymm_g1_real, ymm_g1_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_g1_real, ymm_real[1], ymm_real[3])
    VSUBPS(ymm_g1_imag, ymm_imag[1], ymm_imag[3])

    # H1+, H1- = H1.real + H1.imag, H1.real - H1.imag
    ymm_h1_plus, ymm_h1_minus = YMMRegister(), YMMRegister()
    VADDPS(ymm_h1_plus, ymm_h1_real, ymm_h1_imag)
    VSUBPS(ymm_h1_minus, ymm_h1_real, ymm_h1_imag)

    ymm_sqrt2_over_2 = YMMRegister()
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))

    # w1.real =  G1.real - SQRT2_OVER_2 * H1.plus;
    # w3.real =  G1.real + SQRT2_OVER_2 * H1.plus;
    VMOVAPS(ymm_real[1], ymm_g1_real)
    VFNMADD231PS(ymm_real[1], ymm_h1_plus, ymm_sqrt2_over_2)
    VFMADD231PS(ymm_g1_real, ymm_h1_plus, ymm_sqrt2_over_2)
    SWAP.REGISTERS(ymm_real[3], ymm_g1_real)

    # w1.imag =  G1.imag + SQRT2_OVER_2 * H1.minus;
    # w3.imag = -G1.imag + SQRT2_OVER_2 * H1.minus;
    VMOVAPS(ymm_imag[1], ymm_g1_imag)
    VFMADD231PS(ymm_imag[1], ymm_h1_minus, ymm_sqrt2_over_2)
    VFMSUB231PS(ymm_g1_imag, ymm_h1_minus, ymm_sqrt2_over_2)
    SWAP.REGISTERS(ymm_imag[3], ymm_g1_imag)

    fft.complex_soa.fft4_across_rows(ymm_real, ymm_imag, transformation="inverse")

    if bias is not None:
        # Load bias into a register if it was given as a memory operand.
        ymm_bias = bias
        if not isinstance(bias, YMMRegister):
            ymm_bias = YMMRegister()
            VMOVAPS(ymm_bias, bias)

        ymm_one_eighth = YMMRegister()
        VMOVAPS(ymm_one_eighth, Constant.float32x8(0.125))
        # 1/N scaling, fused with bias: row := row * 0.125 + bias
        for ymm_row in ymm_data:
            VFMADD132PS(ymm_row, ymm_bias, ymm_one_eighth)
def input_transform(ymm_d):
    """Generate AVX2/FMA code transforming 8 input rows d0..d7 into the
    8 transformed rows below, returned as newly allocated registers:

        wd0 = (d0 - d6) + 5.25 * (d4 - d2)
        wd1 = (d6 + d2 - 4.25 * d4) + (d1 + d5 - 4.25 * d3)
        wd2 = (d6 + d2 - 4.25 * d4) - (d1 + d5 - 4.25 * d3)
        wd3 = (d6 + 0.25 * d2 - 1.25 * d4) + 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
        wd4 = (d6 + 0.25 * d2 - 1.25 * d4) - 2.0 * (d5 + 0.25 * d1 - 1.25 * d3)
        wd5 = (d6 - 5.0 * d4 + 4.0 * d2) + 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
        wd6 = (d6 - 5.0 * d4 + 4.0 * d2) - 2.0 * (d1 + 0.25 * d5 - 1.25 * d3)
        wd7 = (d7 - d1) + 5.25 * (d3 - d5)

    ymm_d[1] and ymm_d[6] are used as FMA accumulators and then exchanged
    into the output via SWAP.REGISTERS, so the caller must consider them
    clobbered.
    """
    assert isinstance(ymm_d, list) and len(ymm_d) == 8 and all(
        isinstance(ymm, YMMRegister) for ymm in ymm_d)
    ymm_out = [YMMRegister() for _ in range(8)]

    ymm_quarter = YMMRegister()
    VMOVAPS(ymm_quarter, Constant.float32x8(0.25))

    # First-stage partial sums/differences.
    # out0 <- d0 - d6; also stash d4 - d2 for the 5.25 FMA later.
    VSUBPS(ymm_out[0], ymm_d[0], ymm_d[6])
    ymm_d4_minus_d2 = YMMRegister()
    VSUBPS(ymm_d4_minus_d2, ymm_d[4], ymm_d[2])
    # out7 <- d7 - d1; also stash d3 - d5.
    VSUBPS(ymm_out[7], ymm_d[7], ymm_d[1])
    ymm_d3_minus_d5 = YMMRegister()
    VSUBPS(ymm_d3_minus_d5, ymm_d[3], ymm_d[5])
    # out1 <- d2 + d6
    VADDPS(ymm_out[1], ymm_d[2], ymm_d[6])
    # out2 <- d1 + d5
    VADDPS(ymm_out[2], ymm_d[1], ymm_d[5])
    # out4 <- d5 + 0.25 * d1
    VMOVAPS(ymm_out[4], ymm_d[5])
    VFMADD231PS(ymm_out[4], ymm_d[1], ymm_quarter)
    # out5 <- d6 - 5.0 * d4   (132-form: out5 = -(out5 * d4) + d6, out5 = 5.0)
    VMOVAPS(ymm_out[5], Constant.float32x8(5.0))
    VFNMADD132PS(ymm_out[5], ymm_d[6], ymm_d[4])
    # out3 <- d6 + 0.25 * d2, accumulated directly in d6 and swapped out.
    VFMADD231PS(ymm_d[6], ymm_d[2], ymm_quarter)
    SWAP.REGISTERS(ymm_out[3], ymm_d[6])
    # out6 <- d1 + 0.25 * d5, accumulated directly in d1 and swapped out.
    VFMADD231PS(ymm_d[1], ymm_d[5], ymm_quarter)
    SWAP.REGISTERS(ymm_out[6], ymm_d[1])

    # Finish the outer rows: add the 5.25-scaled differences.
    ymm_21_over_4 = YMMRegister()
    VMOVAPS(ymm_21_over_4, Constant.float32x8(5.25))
    VFMADD231PS(ymm_out[0], ymm_d4_minus_d2, ymm_21_over_4)
    VFMADD231PS(ymm_out[7], ymm_d3_minus_d5, ymm_21_over_4)

    # out1 -= 4.25 * d4; out2 -= 4.25 * d3
    ymm_17_over_4 = YMMRegister()
    VMOVAPS(ymm_17_over_4, Constant.float32x8(4.25))
    VFNMADD231PS(ymm_out[1], ymm_d[4], ymm_17_over_4)
    VFNMADD231PS(ymm_out[2], ymm_d[3], ymm_17_over_4)

    # out3 -= 1.25 * d4; out4 -= 1.25 * d3;
    # out5 += 4.0 * d2;  out6 -= 1.25 * d3
    ymm_5_over_4 = YMMRegister()
    VMOVAPS(ymm_5_over_4, Constant.float32x8(1.25))
    VFNMADD231PS(ymm_out[3], ymm_d[4], ymm_5_over_4)
    VFNMADD231PS(ymm_out[4], ymm_d[3], ymm_5_over_4)
    VFMADD231PS(ymm_out[5], ymm_d[2], Constant.float32x8(4.0))
    VFNMADD231PS(ymm_out[6], ymm_d[3], ymm_5_over_4)

    # Final add/sub pairs; the second member of pairs (3,4) and (5,6)
    # carries the 2.0 factor from the formulas via scale_b.
    ymm_two = YMMRegister()
    VMOVAPS(ymm_two, Constant.float32x8(2.0))
    butterfly(ymm_out[1], ymm_out[2])
    butterfly(ymm_out[3], ymm_out[4], scale_b=ymm_two)
    butterfly(ymm_out[5], ymm_out[6], scale_b=ymm_two)

    return ymm_out
def fft4_across_rows(ymm_real, ymm_imag, transformation="forward"):
    """Generate a complex 4-point FFT (or inverse FFT) across rows.

    ymm_real / ymm_imag each hold 4 YMMRegister rows (SoA complex layout);
    the transform is done in place. transformation selects the twiddle
    sign: "forward" or "inverse". The final row order needs no explicit
    bit reversal; two register swaps put rows back in natural order.
    """
    assert isinstance(ymm_real, list) and len(ymm_real) == 4
    assert isinstance(ymm_imag, list) and len(ymm_imag) == 4
    assert transformation in {"forward", "inverse"}

    # FFT-4 butterfly: combine row n with row n+2, reals and imags alike.
    butterfly(ymm_real[0], ymm_real[2])
    butterfly(ymm_imag[0], ymm_imag[2])
    butterfly(ymm_real[1], ymm_real[3])
    butterfly(ymm_imag[1], ymm_imag[3])

    # FFT-4 twiddle on row 3: swap its real/imag registers; the sign is
    # folded into the negate_b flags of the next butterfly stage.
    SWAP.REGISTERS(ymm_real[3], ymm_imag[3])

    # 2x FFT-2 butterfly.
    butterfly(ymm_real[0], ymm_real[1])
    butterfly(ymm_imag[0], ymm_imag[1])
    if transformation == "forward":
        butterfly(ymm_real[2], ymm_real[3])
        butterfly(ymm_imag[2], ymm_imag[3], negate_b=True)
    else:
        butterfly(ymm_real[2], ymm_real[3], negate_b=True)
        butterfly(ymm_imag[2], ymm_imag[3])

    # Bit reversal: not needed; exchanging rows 1 and 2 suffices.
    SWAP.REGISTERS(ymm_real[1], ymm_real[2])
    SWAP.REGISTERS(ymm_imag[1], ymm_imag[2])
def ifft8_across_rows(ymm_data, bias=None):
    """Generate an inverse 8-point real FFT across rows, in place.

    ymm_data: 8 YMMRegister rows, interleaved (real, imag) pairs.
    bias: optional YMMRegister or memory operand; when provided, the 1/8
    scaling is postponed and fused with the bias addition after the
    transform, otherwise the inputs are scaled up front.
    """
    assert isinstance(ymm_data, list) and len(ymm_data) == 8
    ymm_real, ymm_imag = ymm_data[0::2], ymm_data[1::2]

    if bias is None:
        # Pre-transform 1/N scaling. Rows real[2]/imag[2] get +/-0.25
        # (a 2x factor relative to the 0.125 applied everywhere else).
        ymm_eighth = YMMRegister()
        VMOVAPS(ymm_eighth, Constant.float32x8(0.125))
        for row in ymm_data:
            if row is ymm_real[2]:
                VMULPS(row, row, Constant.float32x8(0.25))
            elif row is ymm_imag[2]:
                VMULPS(row, row, Constant.float32x8(-0.25))
            else:
                VMULPS(row, row, ymm_eighth)
    else:
        # Scaling is deferred to the fused bias pass; apply only the
        # +/-2x adjustment for this row pair now.
        VMULPS(ymm_real[2], ymm_real[2], Constant.float32x8(2.0))
        VMULPS(ymm_imag[2], ymm_imag[2], Constant.float32x8(-2.0))

    butterfly(ymm_real[0], ymm_imag[0])

    # H1: H1.re = W1.re - W3.re, H1.im = W1.im + W3.im
    ymm_H1_re, ymm_H1_im = YMMRegister(), YMMRegister()
    VSUBPS(ymm_H1_re, ymm_real[1], ymm_real[3])
    VADDPS(ymm_H1_im, ymm_imag[1], ymm_imag[3])

    # G1: G1.re = W1.re + W3.re, G1.im = W1.im - W3.im
    ymm_G1_re, ymm_G1_im = YMMRegister(), YMMRegister()
    VADDPS(ymm_G1_re, ymm_real[1], ymm_real[3])
    VSUBPS(ymm_G1_im, ymm_imag[1], ymm_imag[3])

    # H1+ = H1.re + H1.im, H1- = H1.re - H1.im
    ymm_H1_add, ymm_H1_sub = YMMRegister(), YMMRegister()
    VADDPS(ymm_H1_add, ymm_H1_re, ymm_H1_im)
    VSUBPS(ymm_H1_sub, ymm_H1_re, ymm_H1_im)

    ymm_half_sqrt2 = YMMRegister()
    VMOVAPS(ymm_half_sqrt2, Constant.float32x8(sqrt2_over_2))

    # w1.re = G1.re - sqrt(2)/2 * H1+   ;   w3.re = G1.re + sqrt(2)/2 * H1+
    VMOVAPS(ymm_real[1], ymm_G1_re)
    VFNMADD231PS(ymm_real[1], ymm_H1_add, ymm_half_sqrt2)
    VFMADD231PS(ymm_G1_re, ymm_H1_add, ymm_half_sqrt2)
    SWAP.REGISTERS(ymm_real[3], ymm_G1_re)

    # w1.im = G1.im + sqrt(2)/2 * H1-   ;   w3.im = -G1.im + sqrt(2)/2 * H1-
    VMOVAPS(ymm_imag[1], ymm_G1_im)
    VFMADD231PS(ymm_imag[1], ymm_H1_sub, ymm_half_sqrt2)
    VFMSUB231PS(ymm_G1_im, ymm_H1_sub, ymm_half_sqrt2)
    SWAP.REGISTERS(ymm_imag[3], ymm_G1_im)

    complex_soa.fft4_across_rows(ymm_real, ymm_imag, transformation="inverse")

    if bias is not None:
        # Materialize the bias in a register if it came from memory.
        if isinstance(bias, YMMRegister):
            ymm_bias = bias
        else:
            ymm_bias = YMMRegister()
            VMOVAPS(ymm_bias, bias)

        ymm_eighth = YMMRegister()
        VMOVAPS(ymm_eighth, Constant.float32x8(0.125))
        # Fused 1/N scale + bias: row := row * 0.125 + bias
        for row in ymm_data:
            VFMADD132PS(row, ymm_bias, ymm_eighth)
def fft16_within_rows(ymm_real_rows, ymm_imag_rows, bit_reversal=True):
    """Emit an in-place 16-point FFT within rows, SoA complex layout.

    Args:
        ymm_real_rows: list of (YMMRegister, YMMRegister) pairs, each pair
            holding the 16 real lanes of one row (a single pair is
            auto-wrapped into a one-element list).
        ymm_imag_rows: matching list of pairs with the imaginary lanes.
        bit_reversal: when True, a final VPERMPS pass restores natural
            lane order; when False the output stays in the decimated order
            shown in the trailing lane comment.

    The stages run FFT16 -> 2x FFT8 -> 4x FFT4 -> 8x FFT2, re-permuting
    lanes between stages with in-register transposes and shuffles; the
    inline "w[0]/w[1] = ..." comments track lane placement after each
    permutation.
    """
    # Convenience form: accept one row given as bare tuples.
    if isinstance(ymm_real_rows, tuple) and isinstance(ymm_imag_rows, tuple):
        return fft16_within_rows([ymm_real_rows], [ymm_imag_rows])

    assert isinstance(ymm_real_rows, list) and all(
        isinstance(ymm_real, tuple) and all(
            isinstance(ymm, YMMRegister) for ymm in ymm_real)
        for ymm_real in ymm_real_rows)
    assert isinstance(ymm_imag_rows, list) and all(
        isinstance(ymm_imag, tuple) and all(
            isinstance(ymm, YMMRegister) for ymm in ymm_imag)
        for ymm_imag in ymm_imag_rows)

    # FFT16: Butterfly
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

    # FFT16: Multiplication by twiddle factors
    # Complex multiply of w[1] by (cos - i*sin): new_re = re*cos + im*sin,
    # new_im = im*cos - re*sin.
    ymm_fft16_cos_twiddle_factor, ymm_fft16_sin_twiddle_factor = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_fft16_cos_twiddle_factor, Constant.float32x8(*cos_npi_over_8))
    VMOVAPS(ymm_fft16_sin_twiddle_factor, Constant.float32x8(*sin_npi_over_8))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        ymm_new_real1, ymm_new_imag1 = YMMRegister(), YMMRegister()
        VMULPS(ymm_new_real1, ymm_real[1], ymm_fft16_cos_twiddle_factor)
        VMULPS(ymm_new_imag1, ymm_imag[1], ymm_fft16_cos_twiddle_factor)
        VFMADD231PS(ymm_new_real1, ymm_imag[1], ymm_fft16_sin_twiddle_factor)
        VFNMADD231PS(ymm_new_imag1, ymm_real[1], ymm_fft16_sin_twiddle_factor)
        SWAP.REGISTERS(ymm_real[1], ymm_new_real1)
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag1)

    # 2x FFT8: Butterfly
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        transpose2x2x128(ymm_real[0], ymm_real[1])
        transpose2x2x128(ymm_imag[0], ymm_imag[1])

    # w[0] = x0 x1 x2 x3 x8  x9  x10 x11
    # w[1] = x4 x5 x6 x7 x12 x13 x14 x15
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

    # 2x FFT8: Multiplication by twiddle factors
    ymm_fft8_cos_twiddle_factor, ymm_fft8_sin_twiddle_factor = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_fft8_cos_twiddle_factor, Constant.float32x8(*(cos_npi_over_4 * 2)))
    VMOVAPS(ymm_fft8_sin_twiddle_factor, Constant.float32x8(*(sin_npi_over_4 * 2)))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        ymm_new_real1, ymm_new_imag1 = YMMRegister(), YMMRegister()
        VMULPS(ymm_new_real1, ymm_real[1], ymm_fft8_cos_twiddle_factor)
        VMULPS(ymm_new_imag1, ymm_imag[1], ymm_fft8_cos_twiddle_factor)
        VFMADD231PS(ymm_new_real1, ymm_imag[1], ymm_fft8_sin_twiddle_factor)
        VFNMADD231PS(ymm_new_imag1, ymm_real[1], ymm_fft8_sin_twiddle_factor)
        SWAP.REGISTERS(ymm_real[1], ymm_new_real1)
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag1)

    # 4x FFT4: Butterfly
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        transpose2x2x2x64(ymm_real[0], ymm_real[1])
        transpose2x2x2x64(ymm_imag[0], ymm_imag[1])

    # w[0] = x0 x1 x4 x5 x8  x9  x12 x13
    # w[1] = x2 x3 x6 x7 x10 x11 x14 x15
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

    # 4x FFT4: Multiplication by twiddle factors and 8x FFT2: Butterfly
    # NOTE(review): the (3,1,3,1) shuffles deliberately cross real/imag
    # (new_real[1] sources imag[1], new_imag[1] sources real[1]) — this
    # appears to fold the multiply-by-(-i) FFT4 twiddle into the shuffle,
    # with the signs carried by the +1+1-1-1 scale in the imag butterfly;
    # verify against the reference scalar FFT before touching.
    ymm_fft4_twiddle_factor = YMMRegister()
    VMOVAPS(ymm_fft4_twiddle_factor, Constant.float32x8(+1.0, +1.0, -1.0, -1.0, +1.0, +1.0, -1.0, -1.0))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        ymm_new_real = YMMRegister(), YMMRegister()
        VSHUFPS(ymm_new_real[0], ymm_real[0], ymm_real[1], _MM_SHUFFLE(2, 0, 2, 0))
        VSHUFPS(ymm_new_real[1], ymm_real[0], ymm_imag[1], _MM_SHUFFLE(3, 1, 3, 1))
        butterfly(ymm_new_real[0], ymm_new_real[1])

        ymm_new_imag = YMMRegister(), YMMRegister()
        VSHUFPS(ymm_new_imag[0], ymm_imag[0], ymm_imag[1], _MM_SHUFFLE(2, 0, 2, 0))
        VSHUFPS(ymm_new_imag[1], ymm_imag[0], ymm_real[1], _MM_SHUFFLE(3, 1, 3, 1))
        butterfly(ymm_new_imag[0], ymm_new_imag[1], scale_b=ymm_fft4_twiddle_factor)

        SWAP.REGISTERS(ymm_real[0], ymm_new_real[0])
        SWAP.REGISTERS(ymm_real[1], ymm_new_real[1])
        SWAP.REGISTERS(ymm_imag[0], ymm_new_imag[0])
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag[1])

    # w[0] = x0 x4 x2 x6 x8 x12 x10 x14
    # w[1] = x1 x5 x3 x7 x9 x13 x11 x15
    # (high lane of w[1] follows from _MM_SHUFFLE(3,1,3,1) over the
    #  previous layout: a[1],a[3],b[1],b[3] per 128-bit lane)
    if bit_reversal:
        # Bit reversal
        ymm_bit_reversal_mask = YMMRegister()
        VMOVDQA(ymm_bit_reversal_mask, Constant.uint32x8(0, 4, 1, 5, 2, 6, 3, 7))
        for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
            for i in range(2):
                VPERMPS(ymm_real[i], ymm_bit_reversal_mask, ymm_real[i])
                VPERMPS(ymm_imag[i], ymm_bit_reversal_mask, ymm_imag[i])
def ifft16_within_rows(ymm_real_rows, ymm_imag_rows, bit_reversal=True):
    """Emit an in-place inverse 16-point FFT within rows, SoA complex layout.

    Mirrors fft16_within_rows stage-for-stage in reverse order
    (8x FFT2 -> 4x FFT4 -> 2x FFT8 -> FFT16), using conjugate twiddles
    (opposite FMA signs relative to the forward transform) and folding the
    1/16 scale (scale_factor = 0.0625) into the last twiddle multiply and
    final butterfly.

    Args:
        ymm_real_rows: list of (YMMRegister, YMMRegister) pairs holding
            the 16 real lanes of each row (a single pair is auto-wrapped).
        ymm_imag_rows: matching list of pairs with imaginary lanes.
        bit_reversal: when True, assumes natural input order and applies
            the inverse permutation first; when False, input is expected
            already in the decimated order shown in the lane comments.
    """
    # Convenience form: accept one row given as bare tuples.
    if isinstance(ymm_real_rows, tuple) and isinstance(ymm_imag_rows, tuple):
        return ifft16_within_rows([ymm_real_rows], [ymm_imag_rows])

    assert isinstance(ymm_real_rows, list) and all(
        isinstance(ymm_real, tuple) and all(
            isinstance(ymm, YMMRegister) for ymm in ymm_real)
        for ymm_real in ymm_real_rows)
    assert isinstance(ymm_imag_rows, list) and all(
        isinstance(ymm_imag, tuple) and all(
            isinstance(ymm, YMMRegister) for ymm in ymm_imag)
        for ymm_imag in ymm_imag_rows)

    if bit_reversal:
        # Bit reversal (inverse of the forward 0,4,1,5,2,6,3,7 permute)
        # w[0] = x0 x8 x4 x12 x2 x10 x6 x14
        # w[1] = x1 x9 x5 x13 x3 x11 x7 x15
        ymm_bit_reversal_mask = YMMRegister()
        VMOVDQA(ymm_bit_reversal_mask, Constant.uint32x8(0, 2, 4, 6, 1, 3, 5, 7))
        for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
            for i in range(2):
                VPERMPS(ymm_real[i], ymm_bit_reversal_mask, ymm_real[i])
                VPERMPS(ymm_imag[i], ymm_bit_reversal_mask, ymm_imag[i])

    # 8x FFT2: Butterfly
    # w[0] = x0 x4 x2 x6 x8 x12 x10 x14
    # w[1] = x1 x5 x3 x7 x9 x13 x11 x15
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

        # NOTE(review): the UNPCKHPS pairs cross real/imag (new_real[1]
        # sources imag[1], new_imag[1] sources real[1]) — the inverse of
        # the cross-shuffle in the forward transform; verify against the
        # reference before touching.
        ymm_new_real = YMMRegister(), YMMRegister()
        VUNPCKLPS(ymm_new_real[0], ymm_real[0], ymm_real[1])
        VUNPCKHPS(ymm_new_real[1], ymm_real[0], ymm_imag[1])

        ymm_new_imag = YMMRegister(), YMMRegister()
        VUNPCKLPS(ymm_new_imag[0], ymm_imag[0], ymm_imag[1])
        VUNPCKHPS(ymm_new_imag[1], ymm_imag[0], ymm_real[1])

        SWAP.REGISTERS(ymm_imag[0], ymm_new_imag[0])
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag[1])
        SWAP.REGISTERS(ymm_real[0], ymm_new_real[0])
        SWAP.REGISTERS(ymm_real[1], ymm_new_real[1])

    # w[0] = x0 x1 x4 x5 x8  x9  x12 x13
    # w[1] = x2 x3 x6 x7 x10 x11 x14 x15

    # 4x FFT4: Butterfly and multiplication by twiddle factors
    # (the +1,-1 alternating scale carries the twiddle signs)
    ymm_fft4_twiddle_factor = YMMRegister()
    VMOVAPS(ymm_fft4_twiddle_factor, Constant.float32x8(+1.0, -1.0, +1.0, -1.0, +1.0, -1.0, +1.0, -1.0))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1], scale_b=ymm_fft4_twiddle_factor)
        butterfly(ymm_imag[0], ymm_imag[1])

    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        transpose2x2x2x64(ymm_real[0], ymm_real[1])
        transpose2x2x2x64(ymm_imag[0], ymm_imag[1])

    # w[0] = x0 x1 x2 x3 x8  x9  x10 x11
    # w[1] = x4 x5 x6 x7 x12 x13 x14 x15

    # 2x FFT8: Multiplication by twiddle factors
    # Conjugate twiddle (signs flipped vs. forward): new_re = re*cos - im*sin,
    # new_im = im*cos + re*sin.
    ymm_fft8_cos_twiddle_factor, ymm_fft8_sin_twiddle_factor = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_fft8_cos_twiddle_factor, Constant.float32x8(*(cos_npi_over_4 * 2)))
    VMOVAPS(ymm_fft8_sin_twiddle_factor, Constant.float32x8(*(sin_npi_over_4 * 2)))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        ymm_new_real1, ymm_new_imag1 = YMMRegister(), YMMRegister()
        VMULPS(ymm_new_real1, ymm_real[1], ymm_fft8_cos_twiddle_factor)
        VMULPS(ymm_new_imag1, ymm_imag[1], ymm_fft8_cos_twiddle_factor)
        VFNMADD231PS(ymm_new_real1, ymm_imag[1], ymm_fft8_sin_twiddle_factor)
        VFMADD231PS(ymm_new_imag1, ymm_real[1], ymm_fft8_sin_twiddle_factor)
        SWAP.REGISTERS(ymm_real[1], ymm_new_real1)
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag1)

    # 2x FFT8: Butterfly
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        transpose2x2x128(ymm_real[0], ymm_real[1])
        transpose2x2x128(ymm_imag[0], ymm_imag[1])

    # w[0] = x0 x1 x2  x3  x4  x5  x6  x7
    # w[1] = x8 x9 x10 x11 x12 x13 x14 x15

    # FFT16: Multiplication by twiddle factors and scale
    # The 1/16 normalization is pre-multiplied into the twiddle constants.
    scale_factor = 0.0625
    ymm_fft16_cos_scale_twiddle_factor, ymm_fft16_sin_scale_twiddle_factor = YMMRegister(), YMMRegister()
    VMOVAPS(
        ymm_fft16_cos_scale_twiddle_factor,
        Constant.float32x8(*[cos * scale_factor for cos in cos_npi_over_8]))
    VMOVAPS(
        ymm_fft16_sin_scale_twiddle_factor,
        Constant.float32x8(*[sin * scale_factor for sin in sin_npi_over_8]))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        ymm_new_real1, ymm_new_imag1 = YMMRegister(), YMMRegister()
        VMULPS(ymm_new_real1, ymm_real[1], ymm_fft16_cos_scale_twiddle_factor)
        VMULPS(ymm_new_imag1, ymm_imag[1], ymm_fft16_cos_scale_twiddle_factor)
        VFNMADD231PS(ymm_new_real1, ymm_imag[1], ymm_fft16_sin_scale_twiddle_factor)
        VFMADD231PS(ymm_new_imag1, ymm_real[1], ymm_fft16_sin_scale_twiddle_factor)
        SWAP.REGISTERS(ymm_real[1], ymm_new_real1)
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag1)

    # FFT16: Butterfly and scale
    # (w[1] was already scaled above; scale_a applies 1/16 to w[0])
    ymm_scale_factor = YMMRegister()
    VMOVAPS(ymm_scale_factor, Constant.float32x8(scale_factor))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1], scale_a=ymm_scale_factor)
        butterfly(ymm_imag[0], ymm_imag[1], scale_a=ymm_scale_factor)
def inverse_vfft(reg_t0, reg_t8, reg_t_stride, data_in, reg_row_start=None, reg_row_end=None, store_mask=None):
    """Emit an inverse vertical (across 16 rows) real FFT and store rows.

    Args:
        reg_t0: GP64 register with the address where row 0 is stored;
            advanced by reg_t_stride between stores.
        reg_t8: GP64 register with the address where row 8 is stored;
            advanced likewise.
        reg_t_stride: GP64 register, byte stride between consecutive rows.
        data_in: 16 operands (YMM registers or memory), interleaved
            (real, imag) frequency-domain rows W0..W7.
        reg_row_start: optional GP32; rows with index < reg_row_start are
            skipped (CMP/JA guards below). Unsigned comparisons.
        reg_row_end: optional GP32; rows with index >= reg_row_end are
            skipped (CMP/JBE guards below).
        store_mask: optional LocalVariable (YMM-sized) used as the
            VMASKMOVPS mask for partial-width row stores.

    The 1/16 normalization (0.0625) is folded into the input loads;
    W0 and W4 carry different factors (0.0625 / +/-0.125) matching the
    packing of those rows. Two spilled rows live in LocalVariables because
    more than 16 YMM registers are live at the peak.
    """
    assert isinstance(reg_t0, GeneralPurposeRegister64)
    assert isinstance(reg_t8, GeneralPurposeRegister64)
    assert isinstance(reg_t_stride, GeneralPurposeRegister64)
    assert isinstance(data_in, list) and len(data_in) == 16
    assert reg_row_end is None or isinstance(reg_row_end, GeneralPurposeRegister32)
    assert store_mask is None or isinstance(store_mask, LocalVariable) and store_mask.size == YMMRegister.size

    in_real, in_imag = data_in[0::2], data_in[1::2]

    ymm_scale_factor = YMMRegister()
    VMOVAPS(ymm_scale_factor, Constant.float32x8(0.0625))

    # Scale W1..W3 by 1/16 as they are brought into registers.
    ymm_W1_real, ymm_W1_imag = YMMRegister(), YMMRegister()
    VMULPS(ymm_W1_real, ymm_scale_factor, in_real[1])
    VMULPS(ymm_W1_imag, ymm_scale_factor, in_imag[1])
    ymm_W2_real, ymm_W2_imag = YMMRegister(), YMMRegister()
    VMULPS(ymm_W2_real, ymm_scale_factor, in_real[2])
    VMULPS(ymm_W2_imag, ymm_scale_factor, in_imag[2])
    ymm_W3_real, ymm_W3_imag = YMMRegister(), YMMRegister()
    VMULPS(ymm_W3_real, ymm_scale_factor, in_real[3])
    VMULPS(ymm_W3_imag, ymm_scale_factor, in_imag[3])

    # G[n].real, H[n].real = W[n].real + W[8-n].real, W[n].real - W[8-n].real
    # G[n].imag, H[n].imag = W[n].imag - W[8-n].imag, W[n].imag + W[8-n].imag
    # (W5..W7 are scaled by 1/16 inside the butterflies via scale_b)
    ymm_W7_real, ymm_W7_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_W7_real, in_real[7])
    ymm_G1_real, ymm_H1_real = butterfly(ymm_W1_real, ymm_W7_real, scale_b=ymm_scale_factor)
    VMOVUPS(ymm_W7_imag, in_imag[7])
    ymm_G1_imag, ymm_H1_imag = butterfly(ymm_W1_imag, ymm_W7_imag, scale_b=ymm_scale_factor, negate_b=True)

    ymm_W6_real, ymm_W6_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_W6_real, in_real[6])
    ymm_G2_real, ymm_H2_real = butterfly(ymm_W2_real, ymm_W6_real, scale_b=ymm_scale_factor)
    VMOVUPS(ymm_W6_imag, in_imag[6])
    ymm_G2_imag, ymm_H2_imag = butterfly(ymm_W2_imag, ymm_W6_imag, scale_b=ymm_scale_factor, negate_b=True)

    ymm_W5_real, ymm_W5_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_W5_real, in_real[5])
    ymm_G3_real, ymm_H3_real = butterfly(ymm_W3_real, ymm_W5_real, scale_b=ymm_scale_factor)
    VMOVUPS(ymm_W5_imag, in_imag[5])
    ymm_G3_imag, ymm_H3_imag = butterfly(ymm_W3_imag, ymm_W5_imag, scale_b=ymm_scale_factor, negate_b=True)

    # H[2]+, H[2]- = H[2].real + H[2].imag, H[2].real - H[2].imag
    ymm_H2_add, ymm_H2_sub = butterfly(ymm_H2_real, ymm_H2_imag)

    # w[  n].real =  G[  n].real - H[  n].real * cos((N-n)*pi/2N) - H[  n].imag * cos(n*pi/2N)
    # w[2N-n].real =  G[  n].real + H[  n].real * cos((N-n)*pi/2N) + H[  n].imag * cos(n*pi/2N)
    # w[  n].imag =  G[  n].imag + H[  n].real * cos(n*pi/2N) - H[  n].imag * cos((N-n)*pi/2N)
    # w[2N-n].imag = -G[  n].imag + H[  n].real * cos(n*pi/2N) - H[  n].imag * cos((N-n)*pi/2N)
    # w[ N-n].real =  G[N-n].real - H[N-n].real * cos(n*pi/2N) - H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N+n].real =  G[N-n].real + H[N-n].real * cos(n*pi/2N) + H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N-n].imag =  G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)
    # w[ N+n].imag = -G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)
    ymm_cos_1pi_over_8, ymm_cos_3pi_over_8 = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_cos_3pi_over_8, Constant.float32x8(cos_npi_over_8[3]))
    VMOVAPS(ymm_cos_1pi_over_8, Constant.float32x8(cos_npi_over_8[1]))

    # The G registers are reused in place as the second output of each pair
    # (w7/w5/w6 alias the G registers; only the first output gets a copy).
    ymm_w1_real, ymm_w7_real = YMMRegister(), ymm_G1_real
    VMOVAPS(ymm_w1_real, ymm_G1_real)
    VFNMADD231PS(ymm_w1_real, ymm_H1_real, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_w7_real, ymm_H1_real, ymm_cos_3pi_over_8)

    ymm_w1_imag, ymm_w7_imag = YMMRegister(), ymm_G1_imag
    VMOVAPS(ymm_w1_imag, ymm_G1_imag)
    VFMADD231PS(ymm_w1_imag, ymm_H1_real, ymm_cos_1pi_over_8)
    VFMSUB231PS(ymm_w7_imag, ymm_H1_real, ymm_cos_1pi_over_8)

    ymm_w3_real, ymm_w5_real = YMMRegister(), ymm_G3_real
    VMOVAPS(ymm_w3_real, ymm_G3_real)
    VFNMADD231PS(ymm_w3_real, ymm_H3_real, ymm_cos_1pi_over_8)
    VFMADD231PS(ymm_w5_real, ymm_H3_real, ymm_cos_1pi_over_8)

    ymm_w3_imag, ymm_w5_imag = YMMRegister(), ymm_G3_imag
    VMOVAPS(ymm_w3_imag, ymm_G3_imag)
    VFMADD231PS(ymm_w3_imag, ymm_H3_real, ymm_cos_3pi_over_8)
    VFMSUB231PS(ymm_w5_imag, ymm_H3_real, ymm_cos_3pi_over_8)

    ymm_sqrt2_over_2 = YMMRegister()
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))

    # w[ N/2].real =  G[N/2].real - H[N/2]+ * sqrt(2)/2
    # w[ N/2].imag =  G[N/2].imag + H[N/2]- * sqrt(2)/2
    # w[3N/2].real =  G[N/2].real + H[N/2]+ * sqrt(2)/2
    # w[3N/2].imag = -G[N/2].imag + H[N/2]- * sqrt(2)/2
    ymm_w2_real, ymm_w6_real = YMMRegister(), ymm_G2_real
    VMOVAPS(ymm_w2_real, ymm_G2_real)
    VFNMADD231PS(ymm_w2_real, ymm_H2_add, ymm_sqrt2_over_2)
    VFMADD231PS(ymm_w6_real, ymm_H2_add, ymm_sqrt2_over_2)

    ymm_w2_imag, ymm_w6_imag = YMMRegister(), ymm_G2_imag
    VMOVAPS(ymm_w2_imag, ymm_G2_imag)
    VFMADD231PS(ymm_w2_imag, ymm_H2_sub, ymm_sqrt2_over_2)
    VFMSUB231PS(ymm_w6_imag, ymm_H2_sub, ymm_sqrt2_over_2)

    # Second half of the w[n] recombination (H.imag terms); same formulas
    # as the comment block above.
    ymm_cos_1pi_over_8, ymm_cos_3pi_over_8 = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_cos_1pi_over_8, Constant.float32x8(cos_npi_over_8[1]))
    VMOVAPS(ymm_cos_3pi_over_8, Constant.float32x8(cos_npi_over_8[3]))

    VFNMADD231PS(ymm_w1_real, ymm_H1_imag, ymm_cos_1pi_over_8)
    VFMADD231PS(ymm_w7_real, ymm_H1_imag, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_w1_imag, ymm_H1_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_w7_imag, ymm_H1_imag, ymm_cos_3pi_over_8)

    VFNMADD231PS(ymm_w3_real, ymm_H3_imag, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_w5_real, ymm_H3_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_w3_imag, ymm_H3_imag, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_w5_imag, ymm_H3_imag, ymm_cos_1pi_over_8)

    # Interleaved (real, imag) row storage for the row-wise IFFT below;
    # rows 0 (real) and 4 (both) are spilled to stack slots.
    data = [
        LocalVariable(YMMRegister.size), YMMRegister(),
        ymm_w1_real, ymm_w1_imag,
        ymm_w2_real, ymm_w2_imag,
        ymm_w3_real, ymm_w3_imag,
        LocalVariable(YMMRegister.size), LocalVariable(YMMRegister.size),
        ymm_w5_real, ymm_w5_imag,
        ymm_w6_real, ymm_w6_imag,
        ymm_w7_real, ymm_w7_imag
    ]
    real, imag = data[0::2], data[1::2]

    # TODO: optimize
    # Row 0: w0.real/w0.imag = (W0.real +/- W0.imag) / 16
    ymm_w0_real, ymm_w0_imag = YMMRegister(), imag[0]
    VMOVUPS(ymm_w0_real, in_real[0])
    VMOVUPS(ymm_w0_imag, in_imag[0])
    VMULPS(ymm_w0_real, ymm_w0_real, Constant.float32x8(0.0625))
    butterfly(ymm_w0_real, ymm_w0_imag, scale_b=Constant.float32x8(0.0625))
    VMOVAPS(real[0], ymm_w0_real)

    # TODO: optimize
    # Row 4: scaled by +/-1/8 (the 2x relative factor matches the forward
    # packing of this row pair).
    ymm_w4_real, ymm_w4_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_w4_real, in_real[4])
    VMOVUPS(ymm_w4_imag, in_imag[4])
    VMULPS(ymm_w4_real, ymm_w4_real, Constant.float32x8(0.125))
    VMULPS(ymm_w4_imag, ymm_w4_imag, Constant.float32x8(-0.125))
    VMOVAPS(real[4], ymm_w4_real)
    VMOVAPS(imag[4], ymm_w4_imag)

    # Bit reversal (pure Python reshuffle of operand lists; emits no code)
    for i in range(8):
        new_i = fft8_bitreverse(i)
        if new_i > i:
            real[i], real[new_i] = real[new_i], real[i]
            imag[i], imag[new_i] = imag[new_i], imag[i]
    data = interleave(real, imag)

    # 4x FFT2: butterfly
    for i, (data_lo, data_hi) in enumerate(zip(data[0:2] + data[4:6] + data[8:10] + data[12:14], data[2:4] + data[6:8] + data[10:12] + data[14:16])):
        butterfly(data_lo, data_hi)

    # 2x FFT4: multiplication by twiddle factors
    # The dicts defer scale/negate flags to the butterfly that consumes
    # each operand, keyed by operand identity.
    # NOTE(review): fft4_scale_b is declared but never populated here.
    fft4_scale_b, fft4_negate_b = {}, {}
    fft8_scale_b, fft8_negate_b = {}, {}
    # w3.re, w3.im = -w3.im, w3.re
    # w7.re, w7.im = -w7.im, w7.re
    SWAP.REGISTERS(real[3], imag[3])
    fft4_negate_b[id(real[3])] = True
    SWAP.REGISTERS(real[7], imag[7])
    fft4_negate_b[id(real[7])] = True

    # 2x FFT4: butterfly
    for data_lo, data_hi in zip(data[0:4] + data[8:12], data[4:8] + data[12:16]):
        butterfly(data_lo, data_hi, negate_b=fft4_negate_b.get(id(data_hi), False))

    # FFT8: multiplication by twiddle factors
    # w6.re, w6.im = -w6.im, w6.re
    SWAP.REGISTERS(real[6], imag[6])
    fft8_negate_b[id(real[6])] = True

    # w5.re, w5.im = SQRT2_OVER_2 * (w5.re - w5.im), SQRT2_OVER_2 * (w5.re + w5.im)
    butterfly(real[5], imag[5], negate_b=True)
    fft8_scale_b[id(real[5])] = Constant.float32x8(sqrt2_over_2)
    fft8_scale_b[id(imag[5])] = Constant.float32x8(sqrt2_over_2)
    # w7.re, w7.im = -SQRT2_OVER_2 * (w7.re + w7.im), SQRT2_OVER_2 * (w7.re - w7.im)
    butterfly(real[7], imag[7])
    fft8_scale_b[id(real[7])] = Constant.float32x8(sqrt2_over_2)
    fft8_negate_b[id(real[7])] = True
    fft8_scale_b[id(imag[7])] = Constant.float32x8(sqrt2_over_2)

    ymm_store_mask = YMMRegister()
    if store_mask:
        VMOVAPS(ymm_store_mask, store_mask)

    # FFT8: butterfly, fused with the row stores. Row guards use unsigned
    # compares: rows < reg_row_start and rows >= reg_row_end are skipped.
    with Block() as store_data:
        for i, (data_lo, data_hi) in enumerate(zip(data[0:8], data[8:16])):
            row_lo = i
            row_hi = row_lo + 8

            ymm_data_lo, ymm_data_hi = \
                butterfly(data_lo, data_hi, scale_b=fft8_scale_b.get(id(data_hi)), negate_b=fft8_negate_b.get(id(data_hi), False), writeback=False)

            with Block() as store_data_lo:
                if reg_row_start:
                    CMP(reg_row_start, row_lo)
                    JA(store_data_lo.end)
                    if reg_row_end:
                        CMP(reg_row_end, row_lo)
                        JBE(store_data_lo.end)
                elif reg_row_end:
                    # Without a row_start, rows are stored in order, so the
                    # first row past row_end ends the whole store sequence.
                    CMP(reg_row_end, row_lo)
                    JBE(store_data.end)
                if store_mask:
                    VMASKMOVPS([reg_t0], ymm_store_mask, ymm_data_lo)
                else:
                    VMOVUPS([reg_t0], ymm_data_lo)
                if i + 1 != 8:
                    ADD(reg_t0, reg_t_stride)

            with Block() as store_data_hi:
                if reg_row_start:
                    CMP(reg_row_start, row_hi)
                    JA(store_data_hi.end)
                if reg_row_end:
                    CMP(reg_row_end, row_hi)
                    JBE(store_data_hi.end)
                if store_mask:
                    VMASKMOVPS([reg_t8], ymm_store_mask, ymm_data_hi)
                else:
                    VMOVUPS([reg_t8], ymm_data_hi)
                if i + 1 != 8:
                    ADD(reg_t8, reg_t_stride)
def forward_vfft(reg_t0, reg_t8, reg_t_stride, data_out, reg_row_start=None, reg_row_end=None, ymm_load_mask=None):
    """Emit AVX2/FMA code for a forward 16-point real-to-complex FFT down columns.

    Sixteen input rows are read through two pointers (rows 0-7 via reg_t0,
    rows 8-15 via reg_t8, both advanced by reg_t_stride); each YMM lane holds
    one independent column FFT. The 8 complex outputs are written to data_out.

    Args:
        reg_t0: GeneralPurposeRegister64 pointing at input row 0.
        reg_t8: GeneralPurposeRegister64 pointing at input row 8.
        reg_t_stride: GeneralPurposeRegister64 with the byte stride between rows.
        data_out: list of 16 output locations, interleaved [re0, im0, re1, im1, ...].
        reg_row_start: optional register; rows below it are skipped (treated as zero).
        reg_row_end: optional register; rows at/after it are skipped (treated as zero).
        ymm_load_mask: optional YMMRegister mask for VMASKMOVPS partial-row loads.

    Rows outside [reg_row_start, reg_row_end) are zero-filled and their pointer
    is not advanced, so inputs shorter than 16 rows are implicitly zero-padded.
    """
    assert isinstance(reg_t0, GeneralPurposeRegister64)
    assert isinstance(reg_t8, GeneralPurposeRegister64)
    assert isinstance(reg_t_stride, GeneralPurposeRegister64)
    assert isinstance(data_out, list) and len(data_out) == 16
    assert ymm_load_mask is None or isinstance(ymm_load_mask, YMMRegister)

    out_real, out_imag = data_out[0::2], data_out[1::2]

    # Working set: 8 real + 8 imaginary rows. Two of them are spilled to stack
    # slots (LocalVariable) because not everything fits in the 16 YMM registers.
    real, imag = [YMMRegister() for _ in range(8)], [YMMRegister() for _ in range(8)]
    imag[0] = LocalVariable(YMMRegister.size)
    imag[4] = LocalVariable(YMMRegister.size)
    data = interleave(real, imag)

    # Load rows i and i+8 and immediately apply the first (FFT8) butterfly.
    for i, (data_lo, data_hi) in enumerate(zip(data[0:8], data[8:16])):
        row_lo = i
        row_hi = row_lo + 8

        # If the destination is a stack slot, stage through a temporary register.
        ymm_data_lo, ymm_data_hi = data_lo, data_hi
        if isinstance(data_lo, LocalVariable):
            ymm_data_lo = YMMRegister()
        if isinstance(data_hi, LocalVariable):
            ymm_data_hi = YMMRegister()

        # Default to zero so rows outside [row_start, row_end) contribute nothing.
        VXORPS(ymm_data_lo, ymm_data_lo, ymm_data_lo)
        skip_data_lo = Label()
        if reg_row_start:
            CMP(reg_row_start, row_lo)
            JA(skip_data_lo)
        if reg_row_end:
            CMP(reg_row_end, row_lo)
            JBE(skip_data_lo)
        if ymm_load_mask is None:
            VMOVUPS(ymm_data_lo, [reg_t0])
        else:
            VMASKMOVPS(ymm_data_lo, ymm_load_mask, [reg_t0])
        if i + 1 != 8:
            # No advance after the last row: reg_t0 is not needed afterwards.
            ADD(reg_t0, reg_t_stride)
        LABEL(skip_data_lo)

        # If row_hi is skipped, data_hi must equal data_lo (butterfly of x and 0).
        VMOVAPS(ymm_data_hi, ymm_data_lo)
        skip_data_hi = Label()
        if reg_row_start:
            CMP(reg_row_start, row_hi)
            JA(skip_data_hi)
        if reg_row_end:
            CMP(reg_row_end, row_hi)
            JBE(skip_data_hi)
        if ymm_load_mask is None:
            VMOVUPS(ymm_data_hi, [reg_t8])
            butterfly(ymm_data_lo, ymm_data_hi)
        else:
            # Masked path: fuse load + butterfly manually so masked-off lanes
            # keep the zero/copy defaults set above.
            ymm_temp_hi = YMMRegister()
            VMASKMOVPS(ymm_temp_hi, ymm_load_mask, [reg_t8])
            VSUBPS(ymm_data_hi, ymm_data_lo, ymm_temp_hi)
            VADDPS(ymm_data_lo, ymm_data_lo, ymm_temp_hi)
        if i + 1 != 8:
            ADD(reg_t8, reg_t_stride)
        LABEL(skip_data_hi)

        # Write back any value staged for a stack slot.
        if isinstance(data_lo, LocalVariable):
            VMOVAPS(data_lo, ymm_data_lo)
        if isinstance(data_hi, LocalVariable):
            VMOVAPS(data_hi, ymm_data_hi)

    # FFT8: multiplication by twiddle factors.
    # Scales/negations are deferred into later butterflies via these tables,
    # keyed by id() of the operand (register or stack slot).
    fft4_scale_b, fft4_negate_b = {}, {}
    fft2_scale_b, fft2_negate_b = {}, {}

    # w6.re, w6.im = w6.im, -w6.re
    SWAP.REGISTERS(real[6], imag[6])
    fft4_negate_b[id(imag[6])] = True

    # w5.re, w5.im = SQRT2_OVER_2 * (w5.re + w5.im), SQRT2_OVER_2 * (w5.im - w5.re)
    butterfly(imag[5], real[5])
    SWAP.REGISTERS(real[5], imag[5])

    # w7.re, w7.im = -SQRT2_OVER_2 * (w7.re - w7.im), -SQRT2_OVER_2 * (w7.re + w7.im)
    butterfly(real[7], imag[7], negate_b=True)
    fft4_negate_b[id(real[7])] = True
    fft4_negate_b[id(imag[7])] = True

    # Propagate multiplication by sqrt2_over_2 until the last butterfly in FFT2
    ymm_sqrt2_over_2 = YMMRegister()
    fft2_scale_b[id(real[5])] = ymm_sqrt2_over_2
    fft2_scale_b[id(imag[5])] = ymm_sqrt2_over_2
    fft2_scale_b[id(real[7])] = ymm_sqrt2_over_2
    fft2_scale_b[id(imag[7])] = ymm_sqrt2_over_2

    # 2x FFT4: butterfly
    for data_lo, data_hi in zip(data[0:4] + data[8:12], data[4:8] + data[12:16]):
        butterfly(data_lo, data_hi, negate_b=fft4_negate_b.get(id(data_hi), False), scale_b=fft4_scale_b.get(id(data_hi)))

    # 2x FFT4: multiplication by twiddle factors
    # w3.re, w3.im = w3.im, -w3.re
    # w7.re, w7.im = w7.im, -w7.re
    SWAP.REGISTERS(real[3], imag[3])
    SWAP.REGISTERS(real[7], imag[7])
    fft2_negate_b[id(imag[3])] = True
    fft2_negate_b[id(imag[7])] = True

    # 4x FFT2: butterfly
    # Process the first two elements separately
    ymm_real0, ymm_real1 = butterfly(real[0], real[1], writeback=False)
    store_ymm_result(out_real[4], ymm_real1)  # bit-reversal: 1->4
    ymm_imag0, ymm_imag1 = butterfly(imag[0], imag[1], negate_out_b=True, writeback=False)
    store_ymm_result(out_imag[4], ymm_imag1)  # bit-reversal: 1->4

    # ymm_sqrt2_over_2 is loaded only now, just before the butterflies that
    # consume it through the fft2_scale_b table.
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))
    for i, (data_lo, data_hi) in enumerate(zip(data[4:6] + data[8:10] + data[12:14], data[6:8] + data[10:12] + data[14:16])):
        butterfly(data_lo, data_hi, negate_b=fft2_negate_b.get(id(data_hi), False), scale_b=fft2_scale_b.get(id(data_hi)))

    # Real-input FFT packing: outputs 0 and 8 are both real; they are folded
    # into the real/imag parts of output 0.
    butterfly(ymm_real0, ymm_imag0)
    store_ymm_result(out_real[0], ymm_real0)
    store_ymm_result(out_imag[0], ymm_imag0)

    # Bit reversal (in-name-only swap: reorders the Python-side handles).
    for i in range(8):
        new_i = fft8_bitreverse(i)
        if new_i > i:
            real[i], real[new_i] = real[new_i], real[i]
            imag[i], imag[new_i] = imag[new_i], imag[i]
    data = interleave(real, imag)

    # Recombination of the complex FFT into the FFT of the real input:
    # split each pair (k, 8-k) into even (g) and odd (h) parts, doubled.
    ymm_two_g2_real, ymm_two_g2_imag = YMMRegister(), YMMRegister()
    ymm_two_h2_real, ymm_two_h2_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_g2_real, real[2], real[6])
    VSUBPS(ymm_two_h2_imag, real[6], real[2])
    VSUBPS(ymm_two_g2_imag, imag[2], imag[6])
    VADDPS(ymm_two_h2_real, imag[2], imag[6])

    ymm_two_g1_real, ymm_two_g1_imag = YMMRegister(), YMMRegister()
    ymm_two_h1_real, ymm_two_h1_imag = YMMRegister(), YMMRegister()
    # real[1]/imag[1] may live in stack slots after the bit-reversal swap.
    ymm_real1 = load_ymm_variable(real[1])
    VADDPS(ymm_two_g1_real, ymm_real1, real[7])
    VSUBPS(ymm_two_h1_imag, real[7], ymm_real1)
    ymm_imag1 = load_ymm_variable(imag[1])
    VSUBPS(ymm_two_g1_imag, ymm_imag1, imag[7])
    VADDPS(ymm_two_h1_real, ymm_imag1, imag[7])

    ymm_two_h2_add, ymm_two_h2_sub = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_h2_add, ymm_two_h2_real, ymm_two_h2_imag)
    VSUBPS(ymm_two_h2_sub, ymm_two_h2_imag, ymm_two_h2_real)

    ymm_two_g3_real, ymm_two_g3_imag = YMMRegister(), YMMRegister()
    ymm_two_h3_real, ymm_two_h3_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_g3_real, real[3], real[5])
    VSUBPS(ymm_two_h3_imag, real[5], real[3])
    VSUBPS(ymm_two_g3_imag, imag[3], imag[5])
    VADDPS(ymm_two_h3_real, imag[3], imag[5])

    # const float two_w2_real = two_g2_real + SQRT2_OVER_2 * (two_h2_real + two_h2_imag);
    # const float two_w2_imag = two_g2_imag + SQRT2_OVER_2 * (two_h2_imag - two_h2_real);
    # const float two_w6_real = two_g2_real - SQRT2_OVER_2 * (two_h2_real + two_h2_imag);
    # const float two_w6_imag = -two_g2_imag + SQRT2_OVER_2 * (two_h2_imag - two_h2_real);
    ymm_sqrt2_over_2 = YMMRegister()
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))
    # w6 is computed in-place in the g2 registers (aliased below); w2 needs a copy.
    ymm_two_w2_real, ymm_two_w6_real = YMMRegister(), ymm_two_g2_real
    VMOVAPS(ymm_two_w2_real, ymm_two_g2_real)
    VFMADD231PS(ymm_two_w2_real, ymm_two_h2_add, ymm_sqrt2_over_2)
    VFNMADD231PS(ymm_two_w6_real, ymm_two_h2_add, ymm_sqrt2_over_2)
    ymm_two_w2_imag, ymm_two_w6_imag = YMMRegister(), ymm_two_g2_imag
    VMOVAPS(ymm_two_w2_imag, ymm_two_g2_imag)
    VFMADD231PS(ymm_two_w2_imag, ymm_two_h2_sub, ymm_sqrt2_over_2)
    # FMSUB realizes the leading -two_g2_imag term.
    VFMSUB231PS(ymm_two_w6_imag, ymm_two_h2_sub, ymm_sqrt2_over_2)

    # All "two_w*" values carry a factor of 2; halve while storing.
    ymm_half = YMMRegister()
    VMOVAPS(ymm_half, Constant.float32x8(0.5))
    VMULPS(ymm_two_w2_real, ymm_two_w2_real, ymm_half)
    store_ymm_result(out_real[2], ymm_two_w2_real)
    VMULPS(ymm_two_w6_real, ymm_two_w6_real, ymm_half)
    store_ymm_result(out_real[6], ymm_two_w6_real)
    VMULPS(ymm_two_w2_imag, ymm_two_w2_imag, ymm_half)
    store_ymm_result(out_imag[2], ymm_two_w2_imag)
    VMULPS(ymm_two_w6_imag, ymm_two_w6_imag, ymm_half)
    store_ymm_result(out_imag[6], ymm_two_w6_imag)

    # const float two_w1_real = two_g1_real + two_h1_real * COS_1PI_OVER_8 + two_h1_imag * COS_3PI_OVER_8;
    # const float two_w1_imag = two_g1_imag + two_h1_imag * COS_1PI_OVER_8 - two_h1_real * COS_3PI_OVER_8;
    # const float two_w7_real = two_g1_real - two_h1_real * COS_1PI_OVER_8 - two_h1_imag * COS_3PI_OVER_8;
    # const float two_w7_imag = -two_g1_imag + two_h1_imag * COS_1PI_OVER_8 - two_h1_real * COS_3PI_OVER_8;
    # const float two_w3_real = two_g3_real + two_h3_real * COS_3PI_OVER_8 + two_h3_imag * COS_1PI_OVER_8;
    # const float two_w3_imag = two_g3_imag + two_h3_imag * COS_3PI_OVER_8 - two_h3_real * COS_1PI_OVER_8;
    # const float two_w5_real = two_g3_real - two_h3_real * COS_3PI_OVER_8 - two_h3_imag * COS_1PI_OVER_8;
    # const float two_w5_imag = -two_g3_imag + two_h3_imag * COS_3PI_OVER_8 - two_h3_real * COS_1PI_OVER_8;
    # The cos(pi/8) terms are applied first for all four outputs, then the
    # cos(3pi/8) terms, so each constant is loaded exactly once.
    ymm_cos_1pi_over_8 = YMMRegister()
    VMOVAPS(ymm_cos_1pi_over_8, Constant.float32x8(cos_npi_over_8[1]))
    ymm_two_w1_real, ymm_two_w7_real = YMMRegister(), ymm_two_g1_real
    VMOVAPS(ymm_two_w1_real, ymm_two_g1_real)
    VFMADD231PS(ymm_two_w1_real, ymm_two_h1_real, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_two_w7_real, ymm_two_h1_real, ymm_cos_1pi_over_8)
    ymm_two_w1_imag, ymm_two_w7_imag = YMMRegister(), ymm_two_g1_imag
    VMOVAPS(ymm_two_w1_imag, ymm_two_g1_imag)
    VFMADD231PS(ymm_two_w1_imag, ymm_two_h1_imag, ymm_cos_1pi_over_8)
    VFMSUB231PS(ymm_two_w7_imag, ymm_two_h1_imag, ymm_cos_1pi_over_8)
    ymm_two_w3_real, ymm_two_w5_real = YMMRegister(), ymm_two_g3_real
    VMOVAPS(ymm_two_w3_real, ymm_two_g3_real)
    VFMADD231PS(ymm_two_w3_real, ymm_two_h3_imag, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_two_w5_real, ymm_two_h3_imag, ymm_cos_1pi_over_8)
    ymm_two_w3_imag, ymm_two_w5_imag = YMMRegister(), ymm_two_g3_imag
    VMOVAPS(ymm_two_w3_imag, ymm_two_g3_imag)
    VFNMADD231PS(ymm_two_w3_imag, ymm_two_h3_real, ymm_cos_1pi_over_8)
    VFNMSUB231PS(ymm_two_w5_imag, ymm_two_h3_real, ymm_cos_1pi_over_8)

    ymm_cos_3pi_over_8 = YMMRegister()
    VMOVAPS(ymm_cos_3pi_over_8, Constant.float32x8(cos_npi_over_8[3]))
    VFMADD231PS(ymm_two_w1_real, ymm_two_h1_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w7_real, ymm_two_h1_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w1_imag, ymm_two_h1_real, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w7_imag, ymm_two_h1_real, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_two_w3_real, ymm_two_h3_real, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w5_real, ymm_two_h3_real, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_two_w3_imag, ymm_two_h3_imag, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_two_w5_imag, ymm_two_h3_imag, ymm_cos_3pi_over_8)

    # Halve the remaining doubled outputs while storing.
    ymm_half = YMMRegister()
    VMOVAPS(ymm_half, Constant.float32x8(0.5))
    VMULPS(ymm_two_w1_real, ymm_two_w1_real, ymm_half)
    store_ymm_result(out_real[1], ymm_two_w1_real)
    VMULPS(ymm_two_w7_real, ymm_two_w7_real, ymm_half)
    store_ymm_result(out_real[7], ymm_two_w7_real)
    VMULPS(ymm_two_w1_imag, ymm_two_w1_imag, ymm_half)
    store_ymm_result(out_imag[1], ymm_two_w1_imag)
    VMULPS(ymm_two_w7_imag, ymm_two_w7_imag, ymm_half)
    store_ymm_result(out_imag[7], ymm_two_w7_imag)
    VMULPS(ymm_two_w3_real, ymm_two_w3_real, ymm_half)
    store_ymm_result(out_real[3], ymm_two_w3_real)
    VMULPS(ymm_two_w5_real, ymm_two_w5_real, ymm_half)
    store_ymm_result(out_real[5], ymm_two_w5_real)
    VMULPS(ymm_two_w3_imag, ymm_two_w3_imag, ymm_half)
    store_ymm_result(out_imag[3], ymm_two_w3_imag)
    VMULPS(ymm_two_w5_imag, ymm_two_w5_imag, ymm_half)
    store_ymm_result(out_imag[5], ymm_two_w5_imag)
def ifft16_within_rows(ymm_real_rows, ymm_imag_rows, bit_reversal=True):
    """Emit AVX2/FMA code for an inverse 16-point complex FFT within rows.

    Each row is a pair of YMM registers (16 complex values split into a real
    pair and an imaginary pair); several rows are processed in lock-step to
    interleave independent instruction streams. The result is scaled by
    1/16 (scale_factor = 0.0625), applied through the last twiddle stage and
    final butterfly.

    Args:
        ymm_real_rows: list of 2-tuples of YMMRegister (real halves), or a
            single 2-tuple (wrapped into a one-row list).
        ymm_imag_rows: matching list/tuple of imaginary halves.
        bit_reversal: when True, the inputs are first permuted from natural
            to the order the decimation-in-time cascade expects.

    All transforms are in-place in the given registers.
    """
    # Single-row convenience form: recurse with one-element lists.
    if isinstance(ymm_real_rows, tuple) and isinstance(ymm_imag_rows, tuple):
        return ifft16_within_rows([ymm_real_rows], [ymm_imag_rows])

    assert isinstance(ymm_real_rows, list) and all(isinstance(ymm_real, tuple) and all(isinstance(ymm, YMMRegister) for ymm in ymm_real) for ymm_real in ymm_real_rows)
    assert isinstance(ymm_imag_rows, list) and all(isinstance(ymm_imag, tuple) and all(isinstance(ymm, YMMRegister) for ymm in ymm_imag) for ymm_imag in ymm_imag_rows)

    if bit_reversal:
        # Bit reversal
        # w[0] = x0 x8 x4 x12 x2 x10 x6 x14
        # w[1] = x1 x9 x5 x13 x3 x11 x7 x15
        ymm_bit_reversal_mask = YMMRegister()
        VMOVDQA(ymm_bit_reversal_mask, Constant.uint32x8(0, 2, 4, 6, 1, 3, 5, 7))
        for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
            for i in range(2):
                VPERMPS(ymm_real[i], ymm_bit_reversal_mask, ymm_real[i])
                VPERMPS(ymm_imag[i], ymm_bit_reversal_mask, ymm_imag[i])

    # 8x FFT2: Butterfly
    # w[0] = x0 x4 x2 x6 x8 x12 x10 x14
    # w[1] = x1 x5 x3 x7 x9 x13 x11 x15
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

        # Interleave lanes pairwise for the next radix stage.
        # NOTE(review): the high unpacks deliberately cross real/imag halves
        # (real[0] with imag[1] and vice versa) — presumably part of the
        # packed two-real-sequences layout; confirm against the callers.
        ymm_new_real = YMMRegister(), YMMRegister()
        VUNPCKLPS(ymm_new_real[0], ymm_real[0], ymm_real[1])
        VUNPCKHPS(ymm_new_real[1], ymm_real[0], ymm_imag[1])
        ymm_new_imag = YMMRegister(), YMMRegister()
        VUNPCKLPS(ymm_new_imag[0], ymm_imag[0], ymm_imag[1])
        VUNPCKHPS(ymm_new_imag[1], ymm_imag[0], ymm_real[1])
        SWAP.REGISTERS(ymm_imag[0], ymm_new_imag[0])
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag[1])
        SWAP.REGISTERS(ymm_real[0], ymm_new_real[0])
        SWAP.REGISTERS(ymm_real[1], ymm_new_real[1])

    # w[0] = x0 x1 x4 x5 x8 x9 x12 x13
    # w[1] = x2 x3 x6 x7 x10 x11 x14 x15

    # 4x FFT4: Butterfly and multiplication by twiddle factors
    # Alternating +1/-1 folds the FFT4 twiddle (multiplication by -i) into
    # the butterfly of the real halves.
    ymm_fft4_twiddle_factor = YMMRegister()
    VMOVAPS(ymm_fft4_twiddle_factor, Constant.float32x8(+1.0, -1.0, +1.0, -1.0, +1.0, -1.0, +1.0, -1.0))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1], scale_b=ymm_fft4_twiddle_factor)
        butterfly(ymm_imag[0], ymm_imag[1])

    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        transpose2x2x2x64(ymm_real[0], ymm_real[1])
        transpose2x2x2x64(ymm_imag[0], ymm_imag[1])

    # w[0] = x0 x1 x2 x3 x8 x9 x10 x11
    # w[1] = x4 x5 x6 x7 x12 x13 x14 x15

    # 2x FFT8: Multiplication by twiddle factors
    # Complex multiply of the odd half by conj-style twiddles:
    # re' = re*cos - im*sin is realized via VFNMADD; im' = im*cos + re*sin via VFMADD.
    ymm_fft8_cos_twiddle_factor, ymm_fft8_sin_twiddle_factor = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_fft8_cos_twiddle_factor, Constant.float32x8(*(cos_npi_over_4 * 2)))
    VMOVAPS(ymm_fft8_sin_twiddle_factor, Constant.float32x8(*(sin_npi_over_4 * 2)))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        ymm_new_real1, ymm_new_imag1 = YMMRegister(), YMMRegister()
        VMULPS(ymm_new_real1, ymm_real[1], ymm_fft8_cos_twiddle_factor)
        VMULPS(ymm_new_imag1, ymm_imag[1], ymm_fft8_cos_twiddle_factor)
        VFNMADD231PS(ymm_new_real1, ymm_imag[1], ymm_fft8_sin_twiddle_factor)
        VFMADD231PS(ymm_new_imag1, ymm_real[1], ymm_fft8_sin_twiddle_factor)
        SWAP.REGISTERS(ymm_real[1], ymm_new_real1)
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag1)

    # 2x FFT8: Butterfly
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        transpose2x2x128(ymm_real[0], ymm_real[1])
        transpose2x2x128(ymm_imag[0], ymm_imag[1])

    # w[0] = x0 x1 x2 x3 x4 x5 x6 x7
    # w[1] = x8 x9 x10 x11 x12 x13 x14 x15

    # FFT16: Multiplication by twiddle factors and scale
    # The 1/16 IFFT scale is pre-multiplied into the twiddle constants so it
    # costs no extra instructions on the odd half.
    scale_factor = 0.0625
    ymm_fft16_cos_scale_twiddle_factor, ymm_fft16_sin_scale_twiddle_factor = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_fft16_cos_scale_twiddle_factor, Constant.float32x8(*[cos * scale_factor for cos in cos_npi_over_8]))
    VMOVAPS(ymm_fft16_sin_scale_twiddle_factor, Constant.float32x8(*[sin * scale_factor for sin in sin_npi_over_8]))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        ymm_new_real1, ymm_new_imag1 = YMMRegister(), YMMRegister()
        VMULPS(ymm_new_real1, ymm_real[1], ymm_fft16_cos_scale_twiddle_factor)
        VMULPS(ymm_new_imag1, ymm_imag[1], ymm_fft16_cos_scale_twiddle_factor)
        VFNMADD231PS(ymm_new_real1, ymm_imag[1], ymm_fft16_sin_scale_twiddle_factor)
        VFMADD231PS(ymm_new_imag1, ymm_real[1], ymm_fft16_sin_scale_twiddle_factor)
        SWAP.REGISTERS(ymm_real[1], ymm_new_real1)
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag1)

    # FFT16: Butterfly and scale
    # scale_a applies the 1/16 factor to the even half, which bypassed the
    # pre-scaled twiddles above.
    ymm_scale_factor = YMMRegister()
    VMOVAPS(ymm_scale_factor, Constant.float32x8(scale_factor))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1], scale_a=ymm_scale_factor)
        butterfly(ymm_imag[0], ymm_imag[1], scale_a=ymm_scale_factor)
def fft16_within_rows(ymm_real_rows, ymm_imag_rows, bit_reversal=True):
    """Emit AVX2/FMA code for a forward 16-point complex FFT within rows.

    Mirror image of ifft16_within_rows: decimation-in-frequency cascade
    (FFT16 -> 2x FFT8 -> 4x FFT4 -> 8x FFT2) with an optional final
    permutation back to natural order. Each row is a pair of YMM registers;
    several rows are processed in lock-step. Transforms are in-place and
    unscaled (no 1/16 factor here).

    Args:
        ymm_real_rows: list of 2-tuples of YMMRegister (real halves), or a
            single 2-tuple (wrapped into a one-row list).
        ymm_imag_rows: matching list/tuple of imaginary halves.
        bit_reversal: when True, outputs are permuted to natural order at the end.
    """
    # Single-row convenience form: recurse with one-element lists.
    if isinstance(ymm_real_rows, tuple) and isinstance(ymm_imag_rows, tuple):
        return fft16_within_rows([ymm_real_rows], [ymm_imag_rows])

    assert isinstance(ymm_real_rows, list) and all(isinstance(ymm_real, tuple) and all(isinstance(ymm, YMMRegister) for ymm in ymm_real) for ymm_real in ymm_real_rows)
    assert isinstance(ymm_imag_rows, list) and all(isinstance(ymm_imag, tuple) and all(isinstance(ymm, YMMRegister) for ymm in ymm_imag) for ymm_imag in ymm_imag_rows)

    # FFT16: Butterfly
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

    # FFT16: Multiplication by twiddle factors
    # Forward-direction complex multiply: re' = re*cos + im*sin (VFMADD),
    # im' = im*cos - re*sin (VFNMADD) — the sign-mirror of the inverse path.
    ymm_fft16_cos_twiddle_factor, ymm_fft16_sin_twiddle_factor = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_fft16_cos_twiddle_factor, Constant.float32x8(*cos_npi_over_8))
    VMOVAPS(ymm_fft16_sin_twiddle_factor, Constant.float32x8(*sin_npi_over_8))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        ymm_new_real1, ymm_new_imag1 = YMMRegister(), YMMRegister()
        VMULPS(ymm_new_real1, ymm_real[1], ymm_fft16_cos_twiddle_factor)
        VMULPS(ymm_new_imag1, ymm_imag[1], ymm_fft16_cos_twiddle_factor)
        VFMADD231PS(ymm_new_real1, ymm_imag[1], ymm_fft16_sin_twiddle_factor)
        VFNMADD231PS(ymm_new_imag1, ymm_real[1], ymm_fft16_sin_twiddle_factor)
        SWAP.REGISTERS(ymm_real[1], ymm_new_real1)
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag1)

    # 2x FFT8: Butterfly
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        transpose2x2x128(ymm_real[0], ymm_real[1])
        transpose2x2x128(ymm_imag[0], ymm_imag[1])

    # w[0] = x0 x1 x2 x3 x8 x9 x10 x11
    # w[1] = x4 x5 x6 x7 x12 x13 x14 x15
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

    # 2x FFT8: Multiplication by twiddle factors
    ymm_fft8_cos_twiddle_factor, ymm_fft8_sin_twiddle_factor = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_fft8_cos_twiddle_factor, Constant.float32x8(*(cos_npi_over_4 * 2)))
    VMOVAPS(ymm_fft8_sin_twiddle_factor, Constant.float32x8(*(sin_npi_over_4 * 2)))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        ymm_new_real1, ymm_new_imag1 = YMMRegister(), YMMRegister()
        VMULPS(ymm_new_real1, ymm_real[1], ymm_fft8_cos_twiddle_factor)
        VMULPS(ymm_new_imag1, ymm_imag[1], ymm_fft8_cos_twiddle_factor)
        VFMADD231PS(ymm_new_real1, ymm_imag[1], ymm_fft8_sin_twiddle_factor)
        VFNMADD231PS(ymm_new_imag1, ymm_real[1], ymm_fft8_sin_twiddle_factor)
        SWAP.REGISTERS(ymm_real[1], ymm_new_real1)
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag1)

    # 4x FFT4: Butterfly
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        transpose2x2x2x64(ymm_real[0], ymm_real[1])
        transpose2x2x2x64(ymm_imag[0], ymm_imag[1])

    # w[0] = x0 x1 x4 x5 x8 x9 x12 x13
    # w[1] = x2 x3 x6 x7 x10 x11 x14 x15
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        butterfly(ymm_real[0], ymm_real[1])
        butterfly(ymm_imag[0], ymm_imag[1])

    # 4x FFT4: Multiplication by twiddle factors and 8x FFT2: Butterfly
    # The FFT4 twiddle (multiplication by -i) is folded into the imag
    # butterfly through the +1/+1/-1/-1 scale vector.
    ymm_fft4_twiddle_factor = YMMRegister()
    VMOVAPS(ymm_fft4_twiddle_factor, Constant.float32x8(+1.0, +1.0, -1.0, -1.0, +1.0, +1.0, -1.0, -1.0))
    for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
        # NOTE(review): the odd-lane shuffles deliberately cross real/imag
        # halves (real[0] with imag[1] and vice versa) — mirrors the unpack
        # pattern in ifft16_within_rows; confirm against the callers.
        ymm_new_real = YMMRegister(), YMMRegister()
        VSHUFPS(ymm_new_real[0], ymm_real[0], ymm_real[1], _MM_SHUFFLE(2, 0, 2, 0))
        VSHUFPS(ymm_new_real[1], ymm_real[0], ymm_imag[1], _MM_SHUFFLE(3, 1, 3, 1))
        butterfly(ymm_new_real[0], ymm_new_real[1])
        ymm_new_imag = YMMRegister(), YMMRegister()
        VSHUFPS(ymm_new_imag[0], ymm_imag[0], ymm_imag[1], _MM_SHUFFLE(2, 0, 2, 0))
        VSHUFPS(ymm_new_imag[1], ymm_imag[0], ymm_real[1], _MM_SHUFFLE(3, 1, 3, 1))
        butterfly(ymm_new_imag[0], ymm_new_imag[1], scale_b=ymm_fft4_twiddle_factor)
        SWAP.REGISTERS(ymm_real[0], ymm_new_real[0])
        SWAP.REGISTERS(ymm_real[1], ymm_new_real[1])
        SWAP.REGISTERS(ymm_imag[0], ymm_new_imag[0])
        SWAP.REGISTERS(ymm_imag[1], ymm_new_imag[1])

    # w[0] = x0 x4 x2 x6 x8 x12 x10 x14
    # w[1] = x1 x5 x3 x7 x9 x11 x13 x15
    if bit_reversal:
        # Bit reversal: inverse of the permutation used in ifft16_within_rows.
        ymm_bit_reversal_mask = YMMRegister()
        VMOVDQA(ymm_bit_reversal_mask, Constant.uint32x8(0, 4, 1, 5, 2, 6, 3, 7))
        for ymm_real, ymm_imag in zip(ymm_real_rows, ymm_imag_rows):
            for i in range(2):
                VPERMPS(ymm_real[i], ymm_bit_reversal_mask, ymm_real[i])
                VPERMPS(ymm_imag[i], ymm_bit_reversal_mask, ymm_imag[i])
# NOTE(review): this span appears to be a fragment of a larger FFT8 routine
# whose `def` line (and the code following the trailing comment) was lost in
# extraction — it reads 16 rows from `reg_t`, which is not defined at this
# scope, and it ends on a dangling twiddle-factor comment. It closely mirrors
# the load-and-first-butterfly prologue of forward_vfft, but with a single
# base pointer and fixed YMM-sized offsets instead of a strided pair of
# pointers. Recover the original function header before relying on it.

# 8 real + 8 imaginary rows interleaved; two entries are spilled to stack slots.
data = [YMMRegister() for _ in range(16)]
data[0] = LocalVariable(data[0])
data[8] = LocalVariable(data[8])
real, imag = data[0::2], data[1::2]

# Load rows i and i+8 and apply the first (FFT8) butterfly.
for i, (data_lo, data_hi) in enumerate(zip(data[0:8], data[8:16])):
    # Stage stack-slot destinations through temporary registers.
    ymm_data_lo, ymm_data_hi = data_lo, data_hi
    if isinstance(data_lo, LocalVariable):
        ymm_data_lo = YMMRegister()
    if isinstance(data_hi, LocalVariable):
        ymm_data_hi = YMMRegister()

    VMOVUPS(ymm_data_lo, [reg_t + i * YMMRegister.size])
    VMOVUPS(ymm_data_hi, [reg_t + (i + 8) * YMMRegister.size])
    butterfly(ymm_data_lo, ymm_data_hi)

    # Write back any value staged for a stack slot.
    if isinstance(data_lo, LocalVariable):
        VMOVAPS(data_lo, ymm_data_lo)
    if isinstance(data_hi, LocalVariable):
        VMOVAPS(data_hi, ymm_data_hi)

# FFT8: multiplication by twiddle factors
# Deferred scale/negate side tables keyed by id() of the operand.
fft4_scale_b, fft4_negate_b = {}, {}
fft2_scale_b, fft2_negate_b = {}, {}

# w6.re, w6.im = w6.im, -w6.re
SWAP.REGISTERS(real[6], imag[6])
fft4_negate_b[id(imag[6])] = True

# w5.re, w5.im = SQRT2_OVER_2 * (w5.re + w5.im), SQRT2_OVER_2 * (w5.im - w5.re)
def fft8_across_rows(ymm_data):
    """Emit AVX2/FMA code for an 8-point real-to-complex FFT across rows.

    ymm_data holds 8 rows interleaved as [re0, im0, re1, im1, re2, im2,
    re3, im3]; each YMM lane carries one independent column FFT. Builds on a
    4-point complex FFT (fft.complex_soa.fft4_across_rows) of the even/odd
    packing, then recombines it into the FFT of the real input. Results are
    written back into the same registers, in the packed real-FFT layout shown
    in the "Store outputs" comment below.

    Args:
        ymm_data: list of exactly 8 YMMRegister values (asserted).
    """
    assert isinstance(ymm_data, list) and len(ymm_data) == 8
    ymm_real = ymm_data[0::2]
    ymm_imag = ymm_data[1::2]

    # Complex FFT4 of the packed even/odd real samples.
    fft.complex_soa.fft4_across_rows(ymm_real, ymm_imag)

    # Output 0: DC and Nyquist terms are both real; fold them into re0/im0.
    butterfly(ymm_real[0], ymm_imag[0])

    # Recombination of bins 1 and 3 (k and N-k) into the real-input FFT:
    # const float two_gdata1_real = crealf(data1) + crealf(data3);
    # const float two_gdata1_imag = cimagf(data1) - cimagf(data3);
    ymm_two_gdata1_real, ymm_two_gdata1_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_gdata1_real, ymm_real[1], ymm_real[3])
    VSUBPS(ymm_two_gdata1_imag, ymm_imag[1], ymm_imag[3])

    # const float two_hdata1_real = cimagf(data1) + cimagf(data3);
    # const float two_hdata1_imag = crealf(data3) - crealf(data1);
    ymm_two_hdata1_real, ymm_two_hdata1_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_hdata1_real, ymm_imag[1], ymm_imag[3])
    VSUBPS(ymm_two_hdata1_imag, ymm_real[3], ymm_real[1])

    # const float two_hdata1_real_plus_imag = two_hdata1_real + two_hdata1_imag;
    # const float two_hdata1_real_minus_imag = two_hdata1_real - two_hdata1_imag;
    ymm_two_hdata1_plus, ymm_two_hdata1_minus = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_hdata1_plus, ymm_two_hdata1_real, ymm_two_hdata1_imag)
    VSUBPS(ymm_two_hdata1_minus, ymm_two_hdata1_real, ymm_two_hdata1_imag)

    ymm_sqrt2_over_2 = YMMRegister()
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))

    # const float two_data1_real = two_gdata1_real + SQRT2_OVER_2 * two_hdata1_real_plus_imag;
    # const float two_data1_imag = two_gdata1_imag - SQRT2_OVER_2 * two_hdata1_real_minus_imag;
    # const float two_data3_real = two_gdata1_real - SQRT2_OVER_2 * two_hdata1_real_plus_imag;
    # const float two_data3_imag = -two_gdata1_imag - SQRT2_OVER_2 * two_hdata1_real_minus_imag;
    ymm_two_data1_real, ymm_two_data1_imag = YMMRegister(), YMMRegister()
    ymm_two_data3_real, ymm_two_data3_imag = YMMRegister(), YMMRegister()
    # Copy g before the FMAs clobber it in place for data1.
    VMOVAPS(ymm_two_data3_real, ymm_two_gdata1_real)
    VMOVAPS(ymm_two_data3_imag, ymm_two_gdata1_imag)
    VFMADD231PS(ymm_two_gdata1_real, ymm_two_hdata1_plus, ymm_sqrt2_over_2)
    VFNMADD231PS(ymm_two_gdata1_imag, ymm_two_hdata1_minus, ymm_sqrt2_over_2)
    SWAP.REGISTERS(ymm_two_data1_real, ymm_two_gdata1_real)
    SWAP.REGISTERS(ymm_two_data1_imag, ymm_two_gdata1_imag)
    VFNMADD231PS(ymm_two_data3_real, ymm_two_hdata1_plus, ymm_sqrt2_over_2)
    # FNMSUB realizes the leading -two_gdata1_imag term for data3.
    VFNMSUB231PS(ymm_two_data3_imag, ymm_two_hdata1_minus, ymm_sqrt2_over_2)

    # /* Store outputs */
    # fdata[0] = crealf(data0) + cimagf(data0);
    # fdata[1] = crealf(data0) - cimagf(data0);
    # fdata[2] = 0.5f * two_data1_real;
    # fdata[3] = 0.5f * two_data1_imag;
    # fdata[4] = crealf(data2);
    # fdata[5] = -cimagf(data2);
    # fdata[6] = 0.5f * two_data3_real;
    # fdata[7] = 0.5f * two_data3_imag;
    ymm_half = YMMRegister()
    VMOVAPS(ymm_half, Constant.float32x8(0.5))
    VMULPS(ymm_real[1], ymm_two_data1_real, ymm_half)
    VMULPS(ymm_imag[1], ymm_two_data1_imag, ymm_half)
    # Negate imag[2] by flipping its sign bit.
    VXORPS(ymm_imag[2], ymm_imag[2], Constant.float32x8(-0.0))
    VMULPS(ymm_real[3], ymm_two_data3_real, ymm_half)
    VMULPS(ymm_imag[3], ymm_two_data3_imag, ymm_half)
def fft8_across_rows(ymm_data):
    """Emit AVX2/FMA code for an 8-point real-to-complex FFT across rows.

    NOTE(review): near-duplicate of the fft8_across_rows defined just above;
    the only difference is that this version calls
    complex_soa.fft4_across_rows instead of fft.complex_soa.fft4_across_rows.
    If both definitions share one module, the second silently shadows the
    first — presumably these live in different modules of the original tree;
    confirm before merging.

    ymm_data holds 8 rows interleaved as [re0, im0, ..., re3, im3]; each YMM
    lane carries one independent column FFT. Results are written back into the
    same registers in the packed real-FFT layout shown in the "Store outputs"
    comment below.

    Args:
        ymm_data: list of exactly 8 YMMRegister values (asserted).
    """
    assert isinstance(ymm_data, list) and len(ymm_data) == 8
    ymm_real = ymm_data[0::2]
    ymm_imag = ymm_data[1::2]

    # Complex FFT4 of the packed even/odd real samples.
    complex_soa.fft4_across_rows(ymm_real, ymm_imag)

    # Output 0: DC and Nyquist terms are both real; fold them into re0/im0.
    butterfly(ymm_real[0], ymm_imag[0])

    # const float two_gdata1_real = crealf(data1) + crealf(data3);
    # const float two_gdata1_imag = cimagf(data1) - cimagf(data3);
    ymm_two_gdata1_real, ymm_two_gdata1_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_gdata1_real, ymm_real[1], ymm_real[3])
    VSUBPS(ymm_two_gdata1_imag, ymm_imag[1], ymm_imag[3])

    # const float two_hdata1_real = cimagf(data1) + cimagf(data3);
    # const float two_hdata1_imag = crealf(data3) - crealf(data1);
    ymm_two_hdata1_real, ymm_two_hdata1_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_hdata1_real, ymm_imag[1], ymm_imag[3])
    VSUBPS(ymm_two_hdata1_imag, ymm_real[3], ymm_real[1])

    # const float two_hdata1_real_plus_imag = two_hdata1_real + two_hdata1_imag;
    # const float two_hdata1_real_minus_imag = two_hdata1_real - two_hdata1_imag;
    ymm_two_hdata1_plus, ymm_two_hdata1_minus = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_hdata1_plus, ymm_two_hdata1_real, ymm_two_hdata1_imag)
    VSUBPS(ymm_two_hdata1_minus, ymm_two_hdata1_real, ymm_two_hdata1_imag)

    ymm_sqrt2_over_2 = YMMRegister()
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))

    # const float two_data1_real = two_gdata1_real + SQRT2_OVER_2 * two_hdata1_real_plus_imag;
    # const float two_data1_imag = two_gdata1_imag - SQRT2_OVER_2 * two_hdata1_real_minus_imag;
    # const float two_data3_real = two_gdata1_real - SQRT2_OVER_2 * two_hdata1_real_plus_imag;
    # const float two_data3_imag = -two_gdata1_imag - SQRT2_OVER_2 * two_hdata1_real_minus_imag;
    ymm_two_data1_real, ymm_two_data1_imag = YMMRegister(), YMMRegister()
    ymm_two_data3_real, ymm_two_data3_imag = YMMRegister(), YMMRegister()
    # Copy g before the FMAs clobber it in place for data1.
    VMOVAPS(ymm_two_data3_real, ymm_two_gdata1_real)
    VMOVAPS(ymm_two_data3_imag, ymm_two_gdata1_imag)
    VFMADD231PS(ymm_two_gdata1_real, ymm_two_hdata1_plus, ymm_sqrt2_over_2)
    VFNMADD231PS(ymm_two_gdata1_imag, ymm_two_hdata1_minus, ymm_sqrt2_over_2)
    SWAP.REGISTERS(ymm_two_data1_real, ymm_two_gdata1_real)
    SWAP.REGISTERS(ymm_two_data1_imag, ymm_two_gdata1_imag)
    VFNMADD231PS(ymm_two_data3_real, ymm_two_hdata1_plus, ymm_sqrt2_over_2)
    # FNMSUB realizes the leading -two_gdata1_imag term for data3.
    VFNMSUB231PS(ymm_two_data3_imag, ymm_two_hdata1_minus, ymm_sqrt2_over_2)

    # /* Store outputs */
    # fdata[0] = crealf(data0) + cimagf(data0);
    # fdata[1] = crealf(data0) - cimagf(data0);
    # fdata[2] = 0.5f * two_data1_real;
    # fdata[3] = 0.5f * two_data1_imag;
    # fdata[4] = crealf(data2);
    # fdata[5] = -cimagf(data2);
    # fdata[6] = 0.5f * two_data3_real;
    # fdata[7] = 0.5f * two_data3_imag;
    ymm_half = YMMRegister()
    VMOVAPS(ymm_half, Constant.float32x8(0.5))
    VMULPS(ymm_real[1], ymm_two_data1_real, ymm_half)
    VMULPS(ymm_imag[1], ymm_two_data1_imag, ymm_half)
    # Negate imag[2] by flipping its sign bit.
    VXORPS(ymm_imag[2], ymm_imag[2], Constant.float32x8(-0.0))
    VMULPS(ymm_real[3], ymm_two_data3_real, ymm_half)
    VMULPS(ymm_imag[3], ymm_two_data3_imag, ymm_half)