def env_uniform(etime):
    """ Uniform crescendo or diminuendo. """
    dests = [0., random() * .1 + .05, random() * .25 + .75, 0.]
    times = [0., 30., etime * 1000. - 60., 30.]
    curves = [0.] + window([.85, .85]) + [0.]
    if random() < .5:
        return interleave(dests, times, curves)
    return interleave(dests[-1::-1], [0.] + times[-1:0:-1], curves)
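# NOTE: these envelope generators lean on helpers defined elsewhere in the
# package: random() from the stdlib `random` module, plus window(), ratio(),
# _ms(), interleave(), split() and one(). The sketch below is a guess at their
# semantics, inferred purely from the call sites in this file -- an
# illustration, not the package's actual implementation (skip it if the real
# helpers are importable).
from random import random, choice as one

def window(w):
    # Assumed: a random value in [-w, w]; maps element-wise over lists.
    if isinstance(w, list):
        return [window(x) for x in w]
    return (random() * 2. - 1.) * w

def ratio(semitones):
    # Assumed: semitone offset(s) -> playback-rate ratio(s).
    if isinstance(semitones, list):
        return [ratio(s) for s in semitones]
    return 2. ** (semitones / 12.)

def _ms(t, total):
    # Assumed: fraction `t` of an event lasting `total` seconds, in ms.
    return t * total * 1000.

def interleave(*seqs):
    # Assumed: [d0, d1], [t0, t1], [c0, c1] -> [d0, t0, c0, d1, t1, c1].
    return [x for group in zip(*seqs) for x in group]

def split(flat, n):
    # Assumed inverse of interleave: de-interleave into n parallel lists.
    return [flat[i::n] for i in range(n)]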
def env_jagged(etime):
    """
    Jagged envelope: picks at least 2 and at most 4 events for each second
    of the event time. Each event is characterized by a random amplitude
    value and a random time distribution with a curve.
    """
    dests = [0.]
    times = [0.]
    time_error = 0.
    loctime = etime
    while loctime > 0:
        # Scale the last (possibly partial) second by the time remaining.
        correct = min(loctime, 1.)
        numevents = int((random() * 3. + 2.) * correct)
        dests += [random() for i in range(numevents)]
        time_add = [random() / numevents for i in range(numevents)]
        next_error = 1. - sum(time_add)
        if len(time_add) > 0:
            time_add[0] += time_error
        times += time_add
        time_error = next_error
        loctime -= 1.
    dests[-1] = 0.
    times[-1] += etime - sum(times)
    times = [1000. * t for t in times]
    curves = [0.] + [window(.75) for i in range(1, len(dests) - 1)] + [0.]
    return interleave(dests, times, curves)
def pitch_unidir(time, correct):
    """
    Unidirectional pitch shift into a specified pitch.
    Starts at most 3 semitones away from the desired pitch.
    """
    trans = ratio(window(1.) + correct)
    origin = ratio(correct)
    return interleave([trans, origin], [0., time * 1000.], [1., window(.75)])
def amp_in(time):
    """ A double-attack crescendo. """
    ms = lambda t: _ms(t, time)
    first_time = random() * .35 + .2
    second_time = random() * .3 + .2
    decay_time = 1. - first_time - second_time
    first_val = random() * .4 + .4
    dest = [0., first_val, 1., 0.]
    times = [0., ms(first_time), ms(second_time), ms(decay_time)]
    curves = [1.] + window([.75, .75, .75])
    return interleave(dest, times, curves)
def amp_env(time):
    """
    A simple AHD envelope: 20 ms ramp up, hold for most of the time,
    20 ms ramp down.
    """
    ms = lambda t: _ms(t, time)
    attack = decay = 20     # ms
    hold = ms(time) - 40    # ms
    sustain = random() * 0.5 + 0.5
    dest = [0.0, sustain, sustain, 0.0]
    times = [0.0, attack, hold, decay]
    return interleave(dest, times)
def pitch_peak(time, correct):
    """
    A triangle-wave-like pitch shift. Moves up/down from no transposition
    and returns to normal. Guaranteed to reach its destination within the
    first 75% of the time requested.
    """
    ms = lambda t: _ms(t, time)
    trans = ratio(window(1.) + correct)
    origin = ratio(correct)
    # Keep the attack within the first 75% (as the docstring promises),
    # so the decay time stays positive.
    attack_time = random() * .55 + .2
    decay_time = 1. - attack_time
    dest = [origin, trans, origin]
    times = [0., ms(attack_time), ms(decay_time)]
    curves = [1.] + window([.75, .75])
    return interleave(dest, times, curves)
def amp_in_out(time):
    """
    An AHD envelope.
    Attack within the first 75 percent of the time.
    Hold @ the sustain value for up to 20 percent of the time;
    the sustain value is between .5 and 1.
    Decay to 0. for the time remaining.
    """
    ms = lambda t: _ms(t, time)
    attack_time = random() * .55 + .2
    sustain_val = random() * .5 + .5
    hold_time = random() * .2
    decay_time = 1. - attack_time - hold_time
    dest = [0., sustain_val, sustain_val, 0.]
    times = [0., ms(attack_time), ms(hold_time), ms(decay_time)]
    curves = [1., window(.75), 1., window(.75)]
    return interleave(dest, times, curves)
def pitch_cross(time, correct):
    """
    Shifts up then down (or down then up), before returning to the
    original pitch of the sample.
    """
    ms = lambda t: _ms(t, time)
    _ts = lambda: random() * .33 + .1
    # Force the two shifts to opposite directions.
    shifts = window([1., 1.])
    if (shifts[0] * shifts[1]) > 0:
        shifts[one([0, 1])] *= -1
    trans = ratio([s + correct for s in shifts])
    first = _ts()
    second = 1. - _ts() - first
    decay = 1. - (first + second)
    origin = ratio(correct)
    dest = [origin] + trans + [origin]
    times = [0., ms(first), ms(second), ms(decay)]
    curves = [1.] + window([.75, .75, .75])
    return interleave(dest, times, curves)
def ifft16_within_rows_preprocess(ymm_wr, ymm_wi, bit_reversal=False):
    assert isinstance(ymm_wr, (list, tuple)) and len(ymm_wr) == 2 and \
        all(isinstance(reg, YMMRegister) for reg in ymm_wr)
    assert isinstance(ymm_wi, (list, tuple)) and len(ymm_wi) == 2 and \
        all(isinstance(reg, YMMRegister) for reg in ymm_wi)

    # w0.re, w1.re, w2.re, w3.re, w4.re, w5.re, w6.re, w7.re = \
    #     = f0.re, f2.re - f3.im, f4.re - f5.im, f6.re - f7.im, f8.re - f9.im, f10.re - f11.im, f12.re - f13.im, f14.re - f15.im
    # w8.re, w9.re, w10.re, w11.re, w12.re, w13.re, w14.re, w15.re = \
    #     = f0.im, f14.re + f15.im, f12.re + f13.im, f10.re + f11.im, f8.re + f9.im, f6.re + f7.im, f4.re + f5.im, f2.re + f3.im
    #
    # w0.im, w1.im, w2.im, w3.im, w4.im, w5.im, w6.im, w7.im = \
    #     = f1.re, f3.re + f2.im, f5.re + f4.im, f7.re + f6.im, f9.re + f8.im, f11.re + f10.im, f13.re + f12.im, f15.re + f14.im
    # w8.im, w9.im, w10.im, w11.im, w12.im, w13.im, w14.im, w15.im = \
    #     = f1.im, f15.re - f14.im, f13.re - f12.im, f11.re - f10.im, f9.re - f8.im, f7.re - f6.im, f5.re - f4.im, f3.re - f2.im

    # Step 1.A:
    # w0.re, w1.re, w2.re, w3.re, -, w13.re, w14.re, w15.re = \
    #     = f0.re, f2.re - f3.im, f4.re - f5.im, f6.re - f7.im, -, f6.re + f7.im, f4.re + f5.im, f2.re + f3.im
    # w0.im, w1.im, w2.im, w3.im, -, w13.im, w14.im, w15.im = \
    #     = f1.re, f3.re + f2.im, f5.re + f4.im, f7.re + f6.im, -, f7.re - f6.im, f5.re - f4.im, f3.re - f2.im
    ymm_shuffle_02461642 = YMMRegister()
    VMOVDQA(ymm_shuffle_02461642, Constant.uint32x8(0, 2, 4, 6, 1, 6, 4, 2))
    ymm_shuffle_13570753 = YMMRegister()
    VMOVDQA(ymm_shuffle_13570753, Constant.uint32x8(1, 3, 5, 7, 0, 7, 5, 3))

    ymm_fr_02461642, ymm_fi_13570753 = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_02461642, ymm_shuffle_02461642, ymm_wr[0])
    VPERMPS(ymm_fi_13570753, ymm_shuffle_13570753, ymm_wi[0])
    VFMADD231PS(ymm_fr_02461642, ymm_fi_13570753,
        Constant.float32x8(0.0, -1.0, -1.0, -1.0, 0.0, +1.0, +1.0, +1.0))

    ymm_fr_13570753, ymm_fi_02461642 = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_13570753, ymm_shuffle_13570753, ymm_wr[0])
    VPERMPS(ymm_fi_02461642, ymm_shuffle_02461642, ymm_wi[0])
    VFMADD231PS(ymm_fr_13570753, ymm_fi_02461642,
        Constant.float32x8(0.0, +1.0, +1.0, +1.0, 0.0, -1.0, -1.0, -1.0))

    ymm_wr_0123xDEF, ymm_wi_0123xDEF = ymm_fr_02461642, ymm_fr_13570753

    # Step 1.B:
    # -, w9.re, w10.re, w11.re, w4.re, w5.re, w6.re, w7.re = \
    #     = -, f14.re + f15.im, f12.re + f13.im, f10.re + f11.im, f8.re - f9.im, f10.re - f11.im, f12.re - f13.im, f14.re - f15.im
    # -, w9.im, w10.im, w11.im, w4.im, w5.im, w6.im, w7.im = \
    #     = -, f15.re - f14.im, f13.re - f12.im, f11.re - f10.im, f9.re + f8.im, f11.re + f10.im, f13.re + f12.im, f15.re + f14.im
    ymm_shuffle_06420246 = YMMRegister()
    VMOVDQA(ymm_shuffle_06420246, Constant.uint32x8(0, 6, 4, 2, 0, 2, 4, 6))
    ymm_shuffle_17531357 = YMMRegister()
    VMOVDQA(ymm_shuffle_17531357, Constant.uint32x8(1, 7, 5, 3, 1, 3, 5, 7))

    ymm_wr_xxxxCxxx, ymm_wi_xxxxCxxx = YMMRegister(), YMMRegister()
    ymm_wr_0123CDEF, ymm_wi_0123CDEF = YMMRegister(), YMMRegister()

    ymm_fr_8ECA8ACE, ymm_fi_9FDB9BDF = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_8ECA8ACE, ymm_shuffle_06420246, ymm_wr[1])
    VPERMPS(ymm_fi_9FDB9BDF, ymm_shuffle_17531357, ymm_wi[1])
    VADDPS(ymm_wr_xxxxCxxx, ymm_fr_8ECA8ACE, ymm_fi_9FDB9BDF)
    VFMADD231PS(ymm_fr_8ECA8ACE, ymm_fi_9FDB9BDF,
        Constant.float32x8(0.0, +1.0, +1.0, +1.0, -1.0, -1.0, -1.0, -1.0))
    VBLENDPS(ymm_wr_0123CDEF, ymm_wr_0123xDEF, ymm_wr_xxxxCxxx, 0b00010000)

    ymm_fr_9FDB9BDF, ymm_fi_8ECA8ACE = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_9FDB9BDF, ymm_shuffle_17531357, ymm_wr[1])
    VPERMPS(ymm_fi_8ECA8ACE, ymm_shuffle_06420246, ymm_wi[1])
    VSUBPS(ymm_wi_xxxxCxxx, ymm_fr_9FDB9BDF, ymm_fi_8ECA8ACE)
    VFMADD231PS(ymm_fr_9FDB9BDF, ymm_fi_8ECA8ACE,
        Constant.float32x8(0.0, -1.0, -1.0, -1.0, +1.0, +1.0, +1.0, +1.0))
    VBLENDPS(ymm_wi_0123CDEF, ymm_wi_0123xDEF, ymm_wi_xxxxCxxx, 0b00010000)

    ymm_wr_x9AB4567, ymm_wi_x9AB4567 = ymm_fr_8ECA8ACE, ymm_fr_9FDB9BDF

    ymm_wr_89AB4567, ymm_wi_89AB4567 = YMMRegister(), YMMRegister()
    VBLENDPS(ymm_wr_89AB4567, ymm_wr_x9AB4567, ymm_fi_02461642, 0b00000001)
    VBLENDPS(ymm_wi_89AB4567, ymm_wi_x9AB4567, ymm_fi_13570753, 0b00000001)

    ymm_wr_01234567, ymm_wr_89ABCDEF = YMMRegister(), YMMRegister()
    VBLENDPS(ymm_wr_01234567, ymm_wr_0123CDEF, ymm_wr_89AB4567, 0xF0)
    VBLENDPS(ymm_wr_89ABCDEF, ymm_wr_0123CDEF, ymm_wr_89AB4567, 0x0F)

    ymm_wi_01234567, ymm_wi_89ABCDEF = YMMRegister(), YMMRegister()
    VBLENDPS(ymm_wi_01234567, ymm_wi_0123CDEF, ymm_wi_89AB4567, 0xF0)
    VBLENDPS(ymm_wi_89ABCDEF, ymm_wi_0123CDEF, ymm_wi_89AB4567, 0x0F)

    SWAP.REGISTERS(ymm_wr[0], ymm_wr_01234567)
    SWAP.REGISTERS(ymm_wi[0], ymm_wi_01234567)
    SWAP.REGISTERS(ymm_wr[1], ymm_wr_89ABCDEF)
    SWAP.REGISTERS(ymm_wi[1], ymm_wi_89ABCDEF)

    if bit_reversal:
        # Bit reversal:
        # w[0] = x0 x8 x4 x12 x2 x10 x6 x14
        # w[1] = x1 x9 x5 x13 x3 x11 x7 x15
        ymm_bit_reversal_mask = YMMRegister()
        VMOVDQA(ymm_bit_reversal_mask, Constant.uint32x8(0, 2, 4, 6, 1, 3, 5, 7))
        for ymm in interleave(ymm_wr, ymm_wi):
            VPERMPS(ymm, ymm_bit_reversal_mask, ymm)
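# NOTE: the FFT routines in this file build on a butterfly(...) code-generation
# helper defined elsewhere. As a reading aid, here is a scalar Python model of
# the semantics its call sites appear to assume; the exact meaning of
# negate_b, negate_out_b and scale_b is inferred from the surrounding
# comments, so treat this as an assumption rather than the real helper.
def butterfly_model(a, b, negate_b=False, negate_out_b=False, scale_b=1.0):
    """Scalar model of the radix-2 butterfly the PeachPy helper emits."""
    b = b * scale_b
    lo, hi = (a - b, a + b) if negate_b else (a + b, a - b)
    return (lo, -hi) if negate_out_b else (lo, hi)

# e.g. the "w5" twiddle comment below: butterfly(re, im, negate_b=True)
# produces (re - im, re + im) before the deferred sqrt(2)/2 scaling.
assert butterfly_model(3.0, 1.0, negate_b=True) == (2.0, 4.0)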
def forward_vfft(reg_t0, reg_t8, reg_t_stride, data_out,
                 reg_row_start=None, reg_row_end=None, ymm_load_mask=None):
    assert isinstance(reg_t0, GeneralPurposeRegister64)
    assert isinstance(reg_t8, GeneralPurposeRegister64)
    assert isinstance(reg_t_stride, GeneralPurposeRegister64)
    assert isinstance(data_out, list) and len(data_out) == 16
    assert ymm_load_mask is None or isinstance(ymm_load_mask, YMMRegister)

    out_real, out_imag = data_out[0::2], data_out[1::2]

    real, imag = [YMMRegister() for _ in range(8)], [YMMRegister() for _ in range(8)]
    imag[0] = LocalVariable(YMMRegister.size)
    imag[4] = LocalVariable(YMMRegister.size)
    data = interleave(real, imag)

    for i, (data_lo, data_hi) in enumerate(zip(data[0:8], data[8:16])):
        row_lo = i
        row_hi = row_lo + 8

        ymm_data_lo, ymm_data_hi = data_lo, data_hi
        if isinstance(data_lo, LocalVariable):
            ymm_data_lo = YMMRegister()
        if isinstance(data_hi, LocalVariable):
            ymm_data_hi = YMMRegister()

        VXORPS(ymm_data_lo, ymm_data_lo, ymm_data_lo)
        skip_data_lo = Label()
        if reg_row_start:
            CMP(reg_row_start, row_lo)
            JA(skip_data_lo)
        if reg_row_end:
            CMP(reg_row_end, row_lo)
            JBE(skip_data_lo)
        if ymm_load_mask is None:
            VMOVUPS(ymm_data_lo, [reg_t0])
        else:
            VMASKMOVPS(ymm_data_lo, ymm_load_mask, [reg_t0])
        if i + 1 != 8:
            ADD(reg_t0, reg_t_stride)
        LABEL(skip_data_lo)

        VMOVAPS(ymm_data_hi, ymm_data_lo)
        skip_data_hi = Label()
        if reg_row_start:
            CMP(reg_row_start, row_hi)
            JA(skip_data_hi)
        if reg_row_end:
            CMP(reg_row_end, row_hi)
            JBE(skip_data_hi)
        if ymm_load_mask is None:
            VMOVUPS(ymm_data_hi, [reg_t8])
            butterfly(ymm_data_lo, ymm_data_hi)
        else:
            ymm_temp_hi = YMMRegister()
            VMASKMOVPS(ymm_temp_hi, ymm_load_mask, [reg_t8])
            VSUBPS(ymm_data_hi, ymm_data_lo, ymm_temp_hi)
            VADDPS(ymm_data_lo, ymm_data_lo, ymm_temp_hi)
        if i + 1 != 8:
            ADD(reg_t8, reg_t_stride)
        LABEL(skip_data_hi)

        if isinstance(data_lo, LocalVariable):
            VMOVAPS(data_lo, ymm_data_lo)
        if isinstance(data_hi, LocalVariable):
            VMOVAPS(data_hi, ymm_data_hi)

    # FFT8: multiplication by twiddle factors
    fft4_scale_b, fft4_negate_b = {}, {}
    fft2_scale_b, fft2_negate_b = {}, {}

    # w6.re, w6.im = w6.im, -w6.re
    SWAP.REGISTERS(real[6], imag[6])
    fft4_negate_b[id(imag[6])] = True

    # w5.re, w5.im = SQRT2_OVER_2 * (w5.re + w5.im), SQRT2_OVER_2 * (w5.im - w5.re)
    butterfly(imag[5], real[5])
    SWAP.REGISTERS(real[5], imag[5])

    # w7.re, w7.im = -SQRT2_OVER_2 * (w7.re - w7.im), -SQRT2_OVER_2 * (w7.re + w7.im)
    butterfly(real[7], imag[7], negate_b=True)
    fft4_negate_b[id(real[7])] = True
    fft4_negate_b[id(imag[7])] = True

    # Propagate multiplication by sqrt2_over_2 until the last butterfly in FFT2
    ymm_sqrt2_over_2 = YMMRegister()
    fft2_scale_b[id(real[5])] = ymm_sqrt2_over_2
    fft2_scale_b[id(imag[5])] = ymm_sqrt2_over_2
    fft2_scale_b[id(real[7])] = ymm_sqrt2_over_2
    fft2_scale_b[id(imag[7])] = ymm_sqrt2_over_2

    # 2x FFT4: butterfly
    for data_lo, data_hi in zip(data[0:4] + data[8:12], data[4:8] + data[12:16]):
        butterfly(data_lo, data_hi,
                  negate_b=fft4_negate_b.get(id(data_hi), False),
                  scale_b=fft4_scale_b.get(id(data_hi)))

    # 2x FFT4: multiplication by twiddle factors
    # w3.re, w3.im = w3.im, -w3.re
    # w7.re, w7.im = w7.im, -w7.re
    SWAP.REGISTERS(real[3], imag[3])
    SWAP.REGISTERS(real[7], imag[7])
    fft2_negate_b[id(imag[3])] = True
    fft2_negate_b[id(imag[7])] = True

    # 4x FFT2: butterfly
    # Process the first two elements separately
    ymm_real0, ymm_real1 = butterfly(real[0], real[1], writeback=False)
    store_ymm_result(out_real[4], ymm_real1)  # bit reversal: 1 -> 4
    ymm_imag0, ymm_imag1 = butterfly(imag[0], imag[1], negate_out_b=True, writeback=False)
    store_ymm_result(out_imag[4], ymm_imag1)  # bit reversal: 1 -> 4

    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))
    for i, (data_lo, data_hi) in enumerate(zip(data[4:6] + data[8:10] + data[12:14],
                                               data[6:8] + data[10:12] + data[14:16])):
        butterfly(data_lo, data_hi,
                  negate_b=fft2_negate_b.get(id(data_hi), False),
                  scale_b=fft2_scale_b.get(id(data_hi)))

    butterfly(ymm_real0, ymm_imag0)
    store_ymm_result(out_real[0], ymm_real0)
    store_ymm_result(out_imag[0], ymm_imag0)

    # Bit reversal
    for i in range(8):
        new_i = fft8_bitreverse(i)
        if new_i > i:
            real[i], real[new_i] = real[new_i], real[i]
            imag[i], imag[new_i] = imag[new_i], imag[i]
    data = interleave(real, imag)

    ymm_two_g2_real, ymm_two_g2_imag = YMMRegister(), YMMRegister()
    ymm_two_h2_real, ymm_two_h2_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_g2_real, real[2], real[6])
    VSUBPS(ymm_two_h2_imag, real[6], real[2])
    VSUBPS(ymm_two_g2_imag, imag[2], imag[6])
    VADDPS(ymm_two_h2_real, imag[2], imag[6])

    ymm_two_g1_real, ymm_two_g1_imag = YMMRegister(), YMMRegister()
    ymm_two_h1_real, ymm_two_h1_imag = YMMRegister(), YMMRegister()
    ymm_real1 = load_ymm_variable(real[1])
    VADDPS(ymm_two_g1_real, ymm_real1, real[7])
    VSUBPS(ymm_two_h1_imag, real[7], ymm_real1)
    ymm_imag1 = load_ymm_variable(imag[1])
    VSUBPS(ymm_two_g1_imag, ymm_imag1, imag[7])
    VADDPS(ymm_two_h1_real, ymm_imag1, imag[7])

    ymm_two_h2_add, ymm_two_h2_sub = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_h2_add, ymm_two_h2_real, ymm_two_h2_imag)
    VSUBPS(ymm_two_h2_sub, ymm_two_h2_imag, ymm_two_h2_real)

    ymm_two_g3_real, ymm_two_g3_imag = YMMRegister(), YMMRegister()
    ymm_two_h3_real, ymm_two_h3_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_g3_real, real[3], real[5])
    VSUBPS(ymm_two_h3_imag, real[5], real[3])
    VSUBPS(ymm_two_g3_imag, imag[3], imag[5])
    VADDPS(ymm_two_h3_real, imag[3], imag[5])

    # const float two_w2_real = two_g2_real + SQRT2_OVER_2 * (two_h2_real + two_h2_imag);
    # const float two_w2_imag = two_g2_imag + SQRT2_OVER_2 * (two_h2_imag - two_h2_real);
    # const float two_w6_real = two_g2_real - SQRT2_OVER_2 * (two_h2_real + two_h2_imag);
    # const float two_w6_imag = -two_g2_imag + SQRT2_OVER_2 * (two_h2_imag - two_h2_real);
    ymm_sqrt2_over_2 = YMMRegister()
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))
    ymm_two_w2_real, ymm_two_w6_real = YMMRegister(), ymm_two_g2_real
    VMOVAPS(ymm_two_w2_real, ymm_two_g2_real)
    VFMADD231PS(ymm_two_w2_real, ymm_two_h2_add, ymm_sqrt2_over_2)
    VFNMADD231PS(ymm_two_w6_real, ymm_two_h2_add, ymm_sqrt2_over_2)
    ymm_two_w2_imag, ymm_two_w6_imag = YMMRegister(), ymm_two_g2_imag
    VMOVAPS(ymm_two_w2_imag, ymm_two_g2_imag)
    VFMADD231PS(ymm_two_w2_imag, ymm_two_h2_sub, ymm_sqrt2_over_2)
    VFMSUB231PS(ymm_two_w6_imag, ymm_two_h2_sub, ymm_sqrt2_over_2)

    ymm_half = YMMRegister()
    VMOVAPS(ymm_half, Constant.float32x8(0.5))
    VMULPS(ymm_two_w2_real, ymm_two_w2_real, ymm_half)
    store_ymm_result(out_real[2], ymm_two_w2_real)
    VMULPS(ymm_two_w6_real, ymm_two_w6_real, ymm_half)
    store_ymm_result(out_real[6], ymm_two_w6_real)
    VMULPS(ymm_two_w2_imag, ymm_two_w2_imag, ymm_half)
    store_ymm_result(out_imag[2], ymm_two_w2_imag)
    VMULPS(ymm_two_w6_imag, ymm_two_w6_imag, ymm_half)
    store_ymm_result(out_imag[6], ymm_two_w6_imag)

    # const float two_w1_real = two_g1_real + two_h1_real * COS_1PI_OVER_8 + two_h1_imag * COS_3PI_OVER_8;
    # const float two_w1_imag = two_g1_imag + two_h1_imag * COS_1PI_OVER_8 - two_h1_real * COS_3PI_OVER_8;
    # const float two_w7_real = two_g1_real - two_h1_real * COS_1PI_OVER_8 - two_h1_imag * COS_3PI_OVER_8;
    # const float two_w7_imag = -two_g1_imag + two_h1_imag * COS_1PI_OVER_8 - two_h1_real * COS_3PI_OVER_8;
    # const float two_w3_real = two_g3_real + two_h3_real * COS_3PI_OVER_8 + two_h3_imag * COS_1PI_OVER_8;
    # const float two_w3_imag = two_g3_imag + two_h3_imag * COS_3PI_OVER_8 - two_h3_real * COS_1PI_OVER_8;
    # const float two_w5_real = two_g3_real - two_h3_real * COS_3PI_OVER_8 - two_h3_imag * COS_1PI_OVER_8;
    # const float two_w5_imag = -two_g3_imag + two_h3_imag * COS_3PI_OVER_8 - two_h3_real * COS_1PI_OVER_8;
    ymm_cos_1pi_over_8 = YMMRegister()
    VMOVAPS(ymm_cos_1pi_over_8, Constant.float32x8(cos_npi_over_8[1]))
    ymm_two_w1_real, ymm_two_w7_real = YMMRegister(), ymm_two_g1_real
    VMOVAPS(ymm_two_w1_real, ymm_two_g1_real)
    VFMADD231PS(ymm_two_w1_real, ymm_two_h1_real, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_two_w7_real, ymm_two_h1_real, ymm_cos_1pi_over_8)
    ymm_two_w1_imag, ymm_two_w7_imag = YMMRegister(), ymm_two_g1_imag
    VMOVAPS(ymm_two_w1_imag, ymm_two_g1_imag)
    VFMADD231PS(ymm_two_w1_imag, ymm_two_h1_imag, ymm_cos_1pi_over_8)
    VFMSUB231PS(ymm_two_w7_imag, ymm_two_h1_imag, ymm_cos_1pi_over_8)

    ymm_two_w3_real, ymm_two_w5_real = YMMRegister(), ymm_two_g3_real
    VMOVAPS(ymm_two_w3_real, ymm_two_g3_real)
    VFMADD231PS(ymm_two_w3_real, ymm_two_h3_imag, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_two_w5_real, ymm_two_h3_imag, ymm_cos_1pi_over_8)
    ymm_two_w3_imag, ymm_two_w5_imag = YMMRegister(), ymm_two_g3_imag
    VMOVAPS(ymm_two_w3_imag, ymm_two_g3_imag)
    VFNMADD231PS(ymm_two_w3_imag, ymm_two_h3_real, ymm_cos_1pi_over_8)
    VFNMSUB231PS(ymm_two_w5_imag, ymm_two_h3_real, ymm_cos_1pi_over_8)

    ymm_cos_3pi_over_8 = YMMRegister()
    VMOVAPS(ymm_cos_3pi_over_8, Constant.float32x8(cos_npi_over_8[3]))
    VFMADD231PS(ymm_two_w1_real, ymm_two_h1_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w7_real, ymm_two_h1_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w1_imag, ymm_two_h1_real, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w7_imag, ymm_two_h1_real, ymm_cos_3pi_over_8)

    VFMADD231PS(ymm_two_w3_real, ymm_two_h3_real, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w5_real, ymm_two_h3_real, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_two_w3_imag, ymm_two_h3_imag, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_two_w5_imag, ymm_two_h3_imag, ymm_cos_3pi_over_8)

    ymm_half = YMMRegister()
    VMOVAPS(ymm_half, Constant.float32x8(0.5))
    VMULPS(ymm_two_w1_real, ymm_two_w1_real, ymm_half)
    store_ymm_result(out_real[1], ymm_two_w1_real)
    VMULPS(ymm_two_w7_real, ymm_two_w7_real, ymm_half)
    store_ymm_result(out_real[7], ymm_two_w7_real)
    VMULPS(ymm_two_w1_imag, ymm_two_w1_imag, ymm_half)
    store_ymm_result(out_imag[1], ymm_two_w1_imag)
    VMULPS(ymm_two_w7_imag, ymm_two_w7_imag, ymm_half)
    store_ymm_result(out_imag[7], ymm_two_w7_imag)

    VMULPS(ymm_two_w3_real, ymm_two_w3_real, ymm_half)
    store_ymm_result(out_real[3], ymm_two_w3_real)
    VMULPS(ymm_two_w5_real, ymm_two_w5_real, ymm_half)
    store_ymm_result(out_real[5], ymm_two_w5_real)
    VMULPS(ymm_two_w3_imag, ymm_two_w3_imag, ymm_half)
    store_ymm_result(out_imag[3], ymm_two_w3_imag)
    VMULPS(ymm_two_w5_imag, ymm_two_w5_imag, ymm_half)
    store_ymm_result(out_imag[5], ymm_two_w5_imag)
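# NOTE: forward_vfft evaluates a real FFT by packing even/odd samples into a
# half-length complex FFT and then splitting with the G/H (symmetric /
# antisymmetric) recombination whose scalar formulas appear in the comments
# above. A hypothetical NumPy check of the underlying identity (illustrative
# only, not part of the kernel):
import numpy as np

x = np.random.rand(16).astype(np.float32)
z = np.fft.fft(x[0::2] + 1j * x[1::2])       # 8-point complex FFT of packed data
zr = np.conj(z[(-np.arange(8)) % 8])         # conj(Z[(8 - k) mod 8])
g = 0.5 * (z + zr)                           # FFT of the even samples
h = -0.5j * (z - zr)                         # FFT of the odd samples
k = np.arange(8)
X = g + np.exp(-2j * np.pi * k / 16) * h     # first half of the 16-point FFT
assert np.allclose(X, np.fft.fft(x)[:8], atol=1e-4)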
def inverse_vfft(reg_t0, reg_t8, reg_t_stride, data_in,
                 reg_row_start=None, reg_row_end=None, store_mask=None):
    assert isinstance(reg_t0, GeneralPurposeRegister64)
    assert isinstance(reg_t8, GeneralPurposeRegister64)
    assert isinstance(reg_t_stride, GeneralPurposeRegister64)
    assert isinstance(data_in, list) and len(data_in) == 16
    assert reg_row_end is None or isinstance(reg_row_end, GeneralPurposeRegister32)
    assert store_mask is None or isinstance(store_mask, LocalVariable) and \
        store_mask.size == YMMRegister.size

    in_real, in_imag = data_in[0::2], data_in[1::2]

    ymm_scale_factor = YMMRegister()
    VMOVAPS(ymm_scale_factor, Constant.float32x8(0.0625))
    ymm_W1_real, ymm_W1_imag = YMMRegister(), YMMRegister()
    VMULPS(ymm_W1_real, ymm_scale_factor, in_real[1])
    VMULPS(ymm_W1_imag, ymm_scale_factor, in_imag[1])
    ymm_W2_real, ymm_W2_imag = YMMRegister(), YMMRegister()
    VMULPS(ymm_W2_real, ymm_scale_factor, in_real[2])
    VMULPS(ymm_W2_imag, ymm_scale_factor, in_imag[2])
    ymm_W3_real, ymm_W3_imag = YMMRegister(), YMMRegister()
    VMULPS(ymm_W3_real, ymm_scale_factor, in_real[3])
    VMULPS(ymm_W3_imag, ymm_scale_factor, in_imag[3])

    # G[n].real, H[n].real = W[n].real + W[8-n].real, W[n].real - W[8-n].real
    # G[n].imag, H[n].imag = W[n].imag - W[8-n].imag, W[n].imag + W[8-n].imag
    ymm_W7_real, ymm_W7_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_W7_real, in_real[7])
    ymm_G1_real, ymm_H1_real = butterfly(ymm_W1_real, ymm_W7_real, scale_b=ymm_scale_factor)
    VMOVUPS(ymm_W7_imag, in_imag[7])
    ymm_G1_imag, ymm_H1_imag = butterfly(ymm_W1_imag, ymm_W7_imag, scale_b=ymm_scale_factor, negate_b=True)

    ymm_W6_real, ymm_W6_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_W6_real, in_real[6])
    ymm_G2_real, ymm_H2_real = butterfly(ymm_W2_real, ymm_W6_real, scale_b=ymm_scale_factor)
    VMOVUPS(ymm_W6_imag, in_imag[6])
    ymm_G2_imag, ymm_H2_imag = butterfly(ymm_W2_imag, ymm_W6_imag, scale_b=ymm_scale_factor, negate_b=True)

    ymm_W5_real, ymm_W5_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_W5_real, in_real[5])
    ymm_G3_real, ymm_H3_real = butterfly(ymm_W3_real, ymm_W5_real, scale_b=ymm_scale_factor)
    VMOVUPS(ymm_W5_imag, in_imag[5])
    ymm_G3_imag, ymm_H3_imag = butterfly(ymm_W3_imag, ymm_W5_imag, scale_b=ymm_scale_factor, negate_b=True)

    # H[2]+, H[2]- = H[2].real + H[2].imag, H[2].real - H[2].imag
    ymm_H2_add, ymm_H2_sub = butterfly(ymm_H2_real, ymm_H2_imag)

    # w[   n].real =  G[  n].real - H[  n].real * cos((N-n)*pi/2N) - H[  n].imag * cos(n*pi/2N)
    # w[2N-n].real =  G[  n].real + H[  n].real * cos((N-n)*pi/2N) + H[  n].imag * cos(n*pi/2N)
    # w[   n].imag =  G[  n].imag + H[  n].real * cos(n*pi/2N) - H[  n].imag * cos((N-n)*pi/2N)
    # w[2N-n].imag = -G[  n].imag + H[  n].real * cos(n*pi/2N) - H[  n].imag * cos((N-n)*pi/2N)
    # w[ N-n].real =  G[N-n].real - H[N-n].real * cos(n*pi/2N) - H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N+n].real =  G[N-n].real + H[N-n].real * cos(n*pi/2N) + H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N-n].imag =  G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)
    # w[ N+n].imag = -G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)
    ymm_cos_1pi_over_8, ymm_cos_3pi_over_8 = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_cos_3pi_over_8, Constant.float32x8(cos_npi_over_8[3]))
    VMOVAPS(ymm_cos_1pi_over_8, Constant.float32x8(cos_npi_over_8[1]))

    ymm_w1_real, ymm_w7_real = YMMRegister(), ymm_G1_real
    VMOVAPS(ymm_w1_real, ymm_G1_real)
    VFNMADD231PS(ymm_w1_real, ymm_H1_real, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_w7_real, ymm_H1_real, ymm_cos_3pi_over_8)
    ymm_w1_imag, ymm_w7_imag = YMMRegister(), ymm_G1_imag
    VMOVAPS(ymm_w1_imag, ymm_G1_imag)
    VFMADD231PS(ymm_w1_imag, ymm_H1_real, ymm_cos_1pi_over_8)
    VFMSUB231PS(ymm_w7_imag, ymm_H1_real, ymm_cos_1pi_over_8)

    ymm_w3_real, ymm_w5_real = YMMRegister(), ymm_G3_real
    VMOVAPS(ymm_w3_real, ymm_G3_real)
    VFNMADD231PS(ymm_w3_real, ymm_H3_real, ymm_cos_1pi_over_8)
    VFMADD231PS(ymm_w5_real, ymm_H3_real, ymm_cos_1pi_over_8)
    ymm_w3_imag, ymm_w5_imag = YMMRegister(), ymm_G3_imag
    VMOVAPS(ymm_w3_imag, ymm_G3_imag)
    VFMADD231PS(ymm_w3_imag, ymm_H3_real, ymm_cos_3pi_over_8)
    VFMSUB231PS(ymm_w5_imag, ymm_H3_real, ymm_cos_3pi_over_8)

    ymm_sqrt2_over_2 = YMMRegister()
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))
    # w[ N/2].real =  G[N/2].real - H[N/2]+ * sqrt(2)/2
    # w[ N/2].imag =  G[N/2].imag + H[N/2]- * sqrt(2)/2
    # w[3N/2].real =  G[N/2].real + H[N/2]+ * sqrt(2)/2
    # w[3N/2].imag = -G[N/2].imag + H[N/2]- * sqrt(2)/2
    ymm_w2_real, ymm_w6_real = YMMRegister(), ymm_G2_real
    VMOVAPS(ymm_w2_real, ymm_G2_real)
    VFNMADD231PS(ymm_w2_real, ymm_H2_add, ymm_sqrt2_over_2)
    VFMADD231PS(ymm_w6_real, ymm_H2_add, ymm_sqrt2_over_2)
    ymm_w2_imag, ymm_w6_imag = YMMRegister(), ymm_G2_imag
    VMOVAPS(ymm_w2_imag, ymm_G2_imag)
    VFMADD231PS(ymm_w2_imag, ymm_H2_sub, ymm_sqrt2_over_2)
    VFMSUB231PS(ymm_w6_imag, ymm_H2_sub, ymm_sqrt2_over_2)

    # w[   n].real =  G[  n].real - H[  n].real * cos((N-n)*pi/2N) - H[  n].imag * cos(n*pi/2N)
    # w[2N-n].real =  G[  n].real + H[  n].real * cos((N-n)*pi/2N) + H[  n].imag * cos(n*pi/2N)
    # w[   n].imag =  G[  n].imag + H[  n].real * cos(n*pi/2N) - H[  n].imag * cos((N-n)*pi/2N)
    # w[2N-n].imag = -G[  n].imag + H[  n].real * cos(n*pi/2N) - H[  n].imag * cos((N-n)*pi/2N)
    # w[ N-n].real =  G[N-n].real - H[N-n].real * cos(n*pi/2N) - H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N+n].real =  G[N-n].real + H[N-n].real * cos(n*pi/2N) + H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N-n].imag =  G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)
    # w[ N+n].imag = -G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)
    ymm_cos_1pi_over_8, ymm_cos_3pi_over_8 = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_cos_1pi_over_8, Constant.float32x8(cos_npi_over_8[1]))
    VMOVAPS(ymm_cos_3pi_over_8, Constant.float32x8(cos_npi_over_8[3]))

    VFNMADD231PS(ymm_w1_real, ymm_H1_imag, ymm_cos_1pi_over_8)
    VFMADD231PS(ymm_w7_real, ymm_H1_imag, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_w1_imag, ymm_H1_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_w7_imag, ymm_H1_imag, ymm_cos_3pi_over_8)

    VFNMADD231PS(ymm_w3_real, ymm_H3_imag, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_w5_real, ymm_H3_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_w3_imag, ymm_H3_imag, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_w5_imag, ymm_H3_imag, ymm_cos_1pi_over_8)

    data = [
        LocalVariable(YMMRegister.size), YMMRegister(),
        ymm_w1_real, ymm_w1_imag,
        ymm_w2_real, ymm_w2_imag,
        ymm_w3_real, ymm_w3_imag,
        LocalVariable(YMMRegister.size), LocalVariable(YMMRegister.size),
        ymm_w5_real, ymm_w5_imag,
        ymm_w6_real, ymm_w6_imag,
        ymm_w7_real, ymm_w7_imag,
    ]
    real, imag = data[0::2], data[1::2]

    # TODO: optimize
    ymm_w0_real, ymm_w0_imag = YMMRegister(), imag[0]
    VMOVUPS(ymm_w0_real, in_real[0])
    VMOVUPS(ymm_w0_imag, in_imag[0])
    VMULPS(ymm_w0_real, ymm_w0_real, Constant.float32x8(0.0625))
    butterfly(ymm_w0_real, ymm_w0_imag, scale_b=Constant.float32x8(0.0625))
    VMOVAPS(real[0], ymm_w0_real)

    # TODO: optimize
    ymm_w4_real, ymm_w4_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_w4_real, in_real[4])
    VMOVUPS(ymm_w4_imag, in_imag[4])
    VMULPS(ymm_w4_real, ymm_w4_real, Constant.float32x8(0.125))
    VMULPS(ymm_w4_imag, ymm_w4_imag, Constant.float32x8(-0.125))
    VMOVAPS(real[4], ymm_w4_real)
    VMOVAPS(imag[4], ymm_w4_imag)

    # Bit reversal
    for i in range(8):
        new_i = fft8_bitreverse(i)
        if new_i > i:
            real[i], real[new_i] = real[new_i], real[i]
            imag[i], imag[new_i] = imag[new_i], imag[i]
    data = interleave(real, imag)

    # 4x FFT2: butterfly
    for i, (data_lo, data_hi) in enumerate(zip(data[0:2] + data[4:6] + data[8:10] + data[12:14],
                                               data[2:4] + data[6:8] + data[10:12] + data[14:16])):
        butterfly(data_lo, data_hi)

    # 2x FFT4: multiplication by twiddle factors
    fft4_scale_b, fft4_negate_b = {}, {}
    fft8_scale_b, fft8_negate_b = {}, {}

    # w3.re, w3.im = -w3.im, w3.re
    # w7.re, w7.im = -w7.im, w7.re
    SWAP.REGISTERS(real[3], imag[3])
    fft4_negate_b[id(real[3])] = True
    SWAP.REGISTERS(real[7], imag[7])
    fft4_negate_b[id(real[7])] = True

    # 2x FFT4: butterfly
    for data_lo, data_hi in zip(data[0:4] + data[8:12], data[4:8] + data[12:16]):
        butterfly(data_lo, data_hi, negate_b=fft4_negate_b.get(id(data_hi), False))

    # FFT8: multiplication by twiddle factors

    # w6.re, w6.im = -w6.im, w6.re
    SWAP.REGISTERS(real[6], imag[6])
    fft8_negate_b[id(real[6])] = True

    # w5.re, w5.im = SQRT2_OVER_2 * (w5.re - w5.im), SQRT2_OVER_2 * (w5.re + w5.im)
    butterfly(real[5], imag[5], negate_b=True)
    fft8_scale_b[id(real[5])] = Constant.float32x8(sqrt2_over_2)
    fft8_scale_b[id(imag[5])] = Constant.float32x8(sqrt2_over_2)

    # w7.re, w7.im = -SQRT2_OVER_2 * (w7.re + w7.im), SQRT2_OVER_2 * (w7.re - w7.im)
    butterfly(real[7], imag[7])
    fft8_scale_b[id(real[7])] = Constant.float32x8(sqrt2_over_2)
    fft8_negate_b[id(real[7])] = True
    fft8_scale_b[id(imag[7])] = Constant.float32x8(sqrt2_over_2)

    ymm_store_mask = YMMRegister()
    if store_mask:
        VMOVAPS(ymm_store_mask, store_mask)

    # FFT8: butterfly
    with Block() as store_data:
        for i, (data_lo, data_hi) in enumerate(zip(data[0:8], data[8:16])):
            row_lo = i
            row_hi = row_lo + 8

            ymm_data_lo, ymm_data_hi = \
                butterfly(data_lo, data_hi,
                          scale_b=fft8_scale_b.get(id(data_hi)),
                          negate_b=fft8_negate_b.get(id(data_hi), False),
                          writeback=False)

            with Block() as store_data_lo:
                if reg_row_start:
                    CMP(reg_row_start, row_lo)
                    JA(store_data_lo.end)
                    if reg_row_end:
                        CMP(reg_row_end, row_lo)
                        JBE(store_data_lo.end)
                elif reg_row_end:
                    CMP(reg_row_end, row_lo)
                    JBE(store_data.end)
                if store_mask:
                    VMASKMOVPS([reg_t0], ymm_store_mask, ymm_data_lo)
                else:
                    VMOVUPS([reg_t0], ymm_data_lo)
                if i + 1 != 8:
                    ADD(reg_t0, reg_t_stride)

            with Block() as store_data_hi:
                if reg_row_start:
                    CMP(reg_row_start, row_hi)
                    JA(store_data_hi.end)
                if reg_row_end:
                    CMP(reg_row_end, row_hi)
                    JBE(store_data_hi.end)
                if store_mask:
                    VMASKMOVPS([reg_t8], ymm_store_mask, ymm_data_hi)
                else:
                    VMOVUPS([reg_t8], ymm_data_hi)
                if i + 1 != 8:
                    ADD(reg_t8, reg_t_stride)
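# NOTE: both vfft routines assume an fft8_bitreverse helper defined elsewhere;
# for an 8-point FFT it reverses the three index bits, swapping 1 <-> 4 and
# 3 <-> 6. A one-line sketch of the assumed behavior:
def fft8_bitreverse(i):
    # Reverse the three bits of an FFT-8 index (0, 2, 5, 7 are fixed points).
    return ((i & 0b001) << 2) | (i & 0b010) | ((i & 0b100) >> 2)

assert [fft8_bitreverse(i) for i in range(8)] == [0, 4, 2, 6, 1, 5, 3, 7]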
def train_mix(args, labeled_trainloader, unlabeled_trainloader, model, optimizer,
              ema_optimizer, criterion, epoch, use_cuda,
              target_disb, emp_distb_u, pseudo_orig, pseudo_refine):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()

    bar = Bar('Training', max=args.val_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    model.train()
    for batch_idx in range(args.val_iteration):
        # Restart exhausted iterators so the loop runs a fixed number of
        # iterations per epoch.
        try:
            inputs_x, targets_x, _ = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x, targets_x, _ = next(labeled_train_iter)

        try:
            (inputs_u, inputs_u2), _, idx_u = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            (inputs_u, inputs_u2), _, idx_u = next(unlabeled_train_iter)

        # Measure data loading time
        data_time.update(time.time() - end)

        batch_size = inputs_x.size(0)

        # Transform labels to one-hot
        targets_x = torch.zeros(batch_size, args.num_class).scatter_(1, targets_x.view(-1, 1), 1)

        if use_cuda:
            inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
            inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()

        # Generate the pseudo-labels by aggregation and sharpening
        with torch.no_grad():
            outputs_u, _ = model(inputs_u)
            outputs_u2, _ = model(inputs_u2)
            p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
            pt = p ** (1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)

            # Update the saved predictions with the current ones
            p = targets_u
            pseudo_orig[idx_u, :] = p.data.cpu()
            pseudo_orig_backup = pseudo_orig.clone()

            # Applying DARP
            if args.darp and epoch > args.warm:
                if batch_idx % args.num_iter == 0:
                    # Iterative normalization
                    targets_u, weights_u = estimate_pseudo(target_disb, pseudo_orig,
                                                           args.num_class, args.alpha)
                    scale_term = targets_u * weights_u.reshape(1, -1)
                    pseudo_orig = (pseudo_orig * scale_term + 1e-6) \
                                  / (pseudo_orig * scale_term + 1e-6).sum(dim=1, keepdim=True)

                    if args.dataset == 'stl10' or args.dataset == 'cifar100':
                        opt_res = opt_solver(pseudo_orig, target_disb, args.iter_T, 0.3)
                    else:
                        opt_res = opt_solver(pseudo_orig, target_disb, args.iter_T, 0.1)

                    # Save the updated pseudo-labels
                    pseudo_refine = opt_res
                    print(pseudo_refine.sum(dim=0))
                    print(target_disb)

                    # Select the refined targets for this batch
                    targets_u = opt_res[idx_u].detach().cuda()
                    pseudo_orig = pseudo_orig_backup
                else:
                    # Use the previously saved pseudo-labels
                    targets_u = pseudo_refine[idx_u].cuda()

        # Mixup
        all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)

        l = np.random.beta(args.mix_alpha, args.mix_alpha)
        l = max(l, 1 - l)

        idx = torch.randperm(all_inputs.size(0))
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # Interleave labeled and unlabeled samples between batches to get
        # correct batchnorm statistics
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        logits = [model(mixed_input[0])[0]]
        for input in mixed_input[1:]:
            logits.append(model(input)[0])

        # Put interleaved samples back
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        Lx, Lu, w = criterion(args, logits_x, mixed_target[:batch_size],
                              logits_u, mixed_target[batch_size:],
                              epoch + batch_idx / args.val_iteration)
        Lu *= w
        loss = Lx + Lu

        # Record losses
        losses.update(loss.item(), inputs_x.size(0))
        losses_x.update(Lx.item(), inputs_x.size(0))
        losses_u.update(Lu.item(), inputs_x.size(0))
        ws.update(w, inputs_x.size(0))

        # Compute gradient and do an SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                     'Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f}'.format(
                         batch=batch_idx + 1,
                         size=args.val_iteration,
                         data=data_time.avg,
                         bt=batch_time.avg,
                         total=bar.elapsed_td,
                         eta=bar.eta_td,
                         loss=losses.avg,
                         loss_x=losses_x.avg,
                         loss_u=losses_u.avg,
                     )
        bar.next()
    bar.finish()

    return (losses.avg, losses_x.avg, losses_u.avg, emp_distb_u, pseudo_orig, pseudo_refine)
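# NOTE: the interleave(xy, batch) calls above shuffle labeled and unlabeled
# sub-batches so that the first forward pass (the one that updates batchnorm
# statistics) sees samples from every split. The helper is defined elsewhere
# in the repository; the version below is the standard MixMatch implementation,
# which this loop presumably reuses (treat it as an assumption).
def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    # Swap slices between the first tensor and the others so each chunk of
    # size `batch` holds a mix of labeled and unlabeled samples; calling it
    # again with the same arguments undoes the swap.
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]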
def train_remix(args, labeled_trainloader, unlabeled_trainloader, model, optimizer,
                ema_optimizer, criterion, epoch, use_cuda,
                target_disb, emp_distb_u, pseudo_orig, pseudo_refine):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    losses_r = AverageMeter()
    losses_e = AverageMeter()
    ws = AverageMeter()
    end = time.time()

    bar = Bar('Training', max=args.val_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    model.train()
    for batch_idx in range(args.val_iteration):
        # Restart exhausted iterators so the loop runs a fixed number of
        # iterations per epoch.
        try:
            inputs_x, targets_x, _ = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x, targets_x, _ = next(labeled_train_iter)

        try:
            (inputs_u, inputs_u2, inputs_u3), _, idx_u = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            (inputs_u, inputs_u2, inputs_u3), _, idx_u = next(unlabeled_train_iter)

        # Measure data loading time
        data_time.update(time.time() - end)

        batch_size = inputs_x.size(0)

        # Transform labels to one-hot
        targets_x = torch.zeros(batch_size, args.num_class).scatter_(1, targets_x.view(-1, 1), 1)

        if use_cuda:
            inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
            inputs_u, inputs_u2, inputs_u3 = inputs_u.cuda(), inputs_u2.cuda(), inputs_u3.cuda()

        # Rotate images (self-supervised rotation targets, as in ReMixMatch)
        temp = []
        targets_r = torch.randint(0, 4, (inputs_u2.size(0),)).long()
        for i in range(inputs_u2.size(0)):
            inputs_rot = torch.rot90(inputs_u2[i], targets_r[i], [1, 2]).reshape(1, 3, 32, 32)
            temp.append(inputs_rot)
        inputs_r = torch.cat(temp, 0)
        targets_r = torch.zeros(batch_size, 4).scatter_(1, targets_r.view(-1, 1), 1)
        inputs_r, targets_r = inputs_r.cuda(), targets_r.cuda(non_blocking=True)

        # Generate the pseudo-labels
        with torch.no_grad():
            outputs_u, _ = model(inputs_u)
            p = torch.softmax(outputs_u, dim=1)

            # Track the empirical distribution of the unlabeled samples (ReMixMatch)
            real_batch_idx = batch_idx + epoch * args.val_iteration
            if real_batch_idx == 0:
                emp_distb_u = p.mean(0, keepdim=True)
            elif real_batch_idx // 128 == 0:
                emp_distb_u = torch.cat([emp_distb_u, p.mean(0, keepdim=True)], 0)
            else:
                emp_distb_u = emp_distb_u[-127:]
                emp_distb_u = torch.cat([emp_distb_u, p.mean(0, keepdim=True)], 0)

            # Distribution alignment
            if args.align:
                pa = p * (target_disb.cuda() + 1e-6) / (emp_distb_u.mean(0).cuda() + 1e-6)
                p = pa / pa.sum(dim=1, keepdim=True)

            # Temperature scaling
            pt = p ** (1 / args.T)
            targets_u = (pt / pt.sum(dim=1, keepdim=True)).detach()

            # Update the saved predictions with the current ones
            p = targets_u
            pseudo_orig[idx_u, :] = p.data.cpu()
            pseudo_orig_backup = pseudo_orig.clone()

            # Applying DARP
            if args.darp and epoch > args.warm:
                if batch_idx % args.num_iter == 0:
                    # Iterative normalization
                    targets_u, weights_u = estimate_pseudo(target_disb, pseudo_orig,
                                                           args.num_class, args.alpha)
                    scale_term = targets_u * weights_u.reshape(1, -1)
                    pseudo_orig = (pseudo_orig * scale_term + 1e-6) \
                                  / (pseudo_orig * scale_term + 1e-6).sum(dim=1, keepdim=True)

                    if args.dataset == 'stl10' or args.dataset == 'cifar100':
                        opt_res = opt_solver(pseudo_orig, target_disb, args.iter_T, 0.3)
                    else:
                        opt_res = opt_solver(pseudo_orig, target_disb, args.iter_T, 0.1)

                    # Save the updated pseudo-labels
                    pseudo_refine = opt_res

                    # Select the refined targets for this batch
                    targets_u = opt_res[idx_u].detach().cuda()
                    pseudo_orig = pseudo_orig_backup
                else:
                    # Use the previously saved pseudo-labels
                    targets_u = pseudo_refine[idx_u].cuda()

        # Mixup
        all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2, inputs_u3], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u, targets_u], dim=0)

        l = np.random.beta(args.mix_alpha, args.mix_alpha)
        l = max(l, 1 - l)

        idx = torch.randperm(all_inputs.size(0))
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # Interleave labeled and unlabeled samples between batches to get
        # correct batchnorm statistics
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        logits = [model(mixed_input[0])[0]]
        for input in mixed_input[1:]:
            logits.append(model(input)[0])

        # Put interleaved samples back
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        Lx, Lu, w = criterion(args, logits_x, mixed_target[:batch_size],
                              logits_u, mixed_target[batch_size:],
                              epoch + batch_idx / args.val_iteration)
        Lu *= w

        # Rotation loss (self-supervision)
        _, logits_r = model(inputs_r)
        Lr = -1 * torch.mean(torch.sum(F.log_softmax(logits_r, dim=1) * targets_r, dim=1))
        Lr *= args.w_rot

        # Entropy minimization for unlabeled samples (strongly augmented)
        outputs_u2, _ = model(inputs_u2)
        Le = -1 * torch.mean(torch.sum(F.log_softmax(outputs_u2, dim=1) * targets_u, dim=1))
        Le *= args.w_ent * linear_rampup(epoch + batch_idx / args.val_iteration, args.epochs)

        loss = Lx + Lu + Lr + Le

        # Record losses
        losses.update(loss.item(), inputs_x.size(0))
        losses_x.update(Lx.item(), inputs_x.size(0))
        losses_u.update(Lu.item(), inputs_x.size(0))
        losses_r.update(Lr.item(), inputs_x.size(0))
        losses_e.update(Le.item(), inputs_x.size(0))
        ws.update(w, inputs_x.size(0))

        # Compute gradient and do an SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                     'Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | Loss_r: {loss_r:.4f} | ' \
                     'Loss_e: {loss_e:.4f}'.format(
                         batch=batch_idx + 1,
                         size=args.val_iteration,
                         data=data_time.avg,
                         bt=batch_time.avg,
                         total=bar.elapsed_td,
                         eta=bar.eta_td,
                         loss=losses.avg,
                         loss_x=losses_x.avg,
                         loss_u=losses_u.avg,
                         loss_r=losses_r.avg,
                         loss_e=losses_e.avg,
                     )
        bar.next()
    bar.finish()

    return (losses.avg, losses_x.avg, losses_u.avg, emp_distb_u, pseudo_orig, pseudo_refine)
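# NOTE: linear_rampup above schedules the entropy-minimization weight; it is
# defined elsewhere in the repository. The conventional MixMatch-style
# definition, which the call signature here matches, ramps linearly from 0 to 1
# over training (treat this sketch as an assumption):
import numpy as np

def linear_rampup(current, rampup_length):
    # Increase linearly from 0 to 1 over `rampup_length` epochs, then hold at 1.
    if rampup_length == 0:
        return 1.0
    return float(np.clip(current / rampup_length, 0.0, 1.0))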
def amp_out(time):
    """ A double-decay diminuendo. """
    dest, times, curves = split(amp_in(time), 3)
    return interleave(dest[-1::-1], [0.] + times[-1:0:-1], curves)
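# Smoke test for the envelope generators: under the helper semantics sketched
# earlier, each call returns a flat breakpoint list for a 2-second event
# ([dest, time, curve, ...]; amp_env omits curves). Illustrative only.
if __name__ == '__main__':
    for gen in (env_uniform, env_jagged, amp_in, amp_out, amp_in_out):
        print(gen.__name__, gen(2.))
    print(amp_env.__name__, amp_env(2.))
    print(pitch_peak.__name__, pitch_peak(2., 0.))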
def cgemm_loop(ymm_c, reg_a, reg_b, reg_k, step_k, loop, conjugate_b, mixed_columns):
    ymm_c_real, ymm_c_imag = ymm_c

    assert isinstance(reg_k, GeneralPurposeRegister64)
    assert isinstance(step_k, int) and step_k >= 1
    assert isinstance(mixed_columns, bool)
    assert isinstance(loop, Loop)
    assert isinstance(ymm_c_real, list) and isinstance(ymm_c_imag, list) and \
        len(ymm_c_real) == len(ymm_c_imag)
    mr = len(ymm_c_real)
    assert isinstance(ymm_c_real[0], list)
    nr = len(ymm_c_real[0])
    assert all(isinstance(ymm_c_real_m, list) and len(ymm_c_real_m) == nr for ymm_c_real_m in ymm_c_real)
    assert all(isinstance(ymm_c_imag_m, list) and len(ymm_c_imag_m) == nr for ymm_c_imag_m in ymm_c_imag)

    step_a, step_b = mr * step_k * YMMRegister.size * 2, nr * step_k * YMMRegister.size * 2
    disp_shift_a = 0 if step_a <= 128 else -128
    disp_shift_b = 0 if step_b <= 128 else -128

    ymm_a_real, ymm_a_imag = tuple([YMMRegister() for m in range(mr)] for c in range(2))
    ymm_b_real, ymm_b_imag = tuple([YMMRegister() for n in range(nr)] for c in range(2))

    if step_k > 1:
        if disp_shift_a != 0:
            SUB(reg_a, disp_shift_a)
        if disp_shift_b != 0:
            SUB(reg_b, disp_shift_b)
        SUB(reg_k, step_k)
        JB(loop.end)

    with loop:
        for k in range(step_k):
            for i, ymm_a in enumerate(interleave(ymm_a_real, ymm_a_imag)):
                VMOVAPS(ymm_a, [reg_a + (i + 2 * mr * k) * YMMRegister.size + disp_shift_a])
            for i, ymm_b in enumerate(interleave(ymm_b_real, ymm_b_imag)):
                VMOVAPS(ymm_b, [reg_b + (i + 2 * nr * k) * YMMRegister.size + disp_shift_b])

            for n in range(nr):
                for m in range(mr):
                    VFMADD231PS(ymm_c_real[m][n], ymm_a_real[m], ymm_b_real[n])
                if mixed_columns:
                    VBLENDPS(ymm_b_real[n], ymm_b_real[n], ymm_b_imag[n], 0b00000011)
                for m in range(mr):
                    VFMADD231PS(ymm_c_imag[m][n], ymm_a_imag[m], ymm_b_real[n])

            if nr > 1 and mixed_columns:
                zero_columns01_mask = YMMRegister()
                VMOVAPS(zero_columns01_mask, Constant.uint32x8(
                    0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
                    0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF))
            else:
                zero_columns01_mask = Constant.uint32x8(
                    0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
                    0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)

            # if step_k > 1:
            #     PREFETCHNTA([reg_a + 640])

            for n in range(nr):
                if mixed_columns:
                    VANDPS(ymm_b_imag[n], ymm_b_imag[n], zero_columns01_mask)
                for m in range(mr):
                    if conjugate_b:
                        VFMADD231PS(ymm_c_real[m][n], ymm_a_imag[m], ymm_b_imag[n])
                        VFNMADD231PS(ymm_c_imag[m][n], ymm_a_real[m], ymm_b_imag[n])
                    else:
                        VFNMADD231PS(ymm_c_real[m][n], ymm_a_imag[m], ymm_b_imag[n])
                        VFMADD231PS(ymm_c_imag[m][n], ymm_a_real[m], ymm_b_imag[n])

        SUB(reg_a, -step_a)
        SUB(reg_b, -step_b)
        if step_k > 1:
            SUB(reg_k, step_k)
            JAE(loop.begin)
        else:
            DEC(reg_k)
            JNE(loop.begin)

    if step_k > 1:
        if disp_shift_a:
            ADD(reg_a, disp_shift_a)
        if disp_shift_b:
            ADD(reg_b, disp_shift_b)
        ADD(reg_k, step_k)
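# NOTE: per (m, n) micro-tile, the two FMA groups in the loop body accumulate
# one complex multiply: C += A * B, or C += A * conj(B) when conjugate_b is
# set. A scalar check of the sign pattern (illustrative only, not part of the
# generated kernel):
import numpy as np

a, b, c0 = 0.3 + 0.7j, -1.1 + 0.2j, 0.5 - 0.25j

# conjugate_b=False: c.re += a.re*b.re - a.im*b.im; c.im += a.im*b.re + a.re*b.im
c_re, c_im = c0.real, c0.imag
c_re += a.real * b.real - a.imag * b.imag    # VFMADD231PS / VFNMADD231PS pair
c_im += a.imag * b.real + a.real * b.imag
assert np.isclose(complex(c_re, c_im), c0 + a * b)

# conjugate_b=True flips both signs on the b.im terms: C += A * conj(B)
c_re, c_im = c0.real, c0.imag
c_re += a.real * b.real + a.imag * b.imag
c_im += a.imag * b.real - a.real * b.imag
assert np.isclose(complex(c_re, c_im), c0 + a * np.conj(b))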