Example #1
0
def env_uniform(etime):
    """
    Uniform crescendo or diminuendo envelope.

    Builds a four-point envelope (silence, low level, high level,
    silence). Half the time it is returned as-is (crescendo shape);
    otherwise the destinations and times are reversed (diminuendo).
    """
    low = random() * .1 + .05
    high = random() * .25 + .75
    dests = [0., low, high, 0.]
    times = [0., 30., etime * 1000. - 60., 30.]
    curves = [0.] + window([.85, .85]) + [0.]
    if random() < .5:
        return interleave(dests, times, curves)
    reversed_dests = dests[::-1]
    reversed_times = [0.] + times[:0:-1]
    return interleave(reversed_dests, reversed_times, curves)
Example #2
0
def env_jagged(etime):
    """
    Jagged envelope.

    Picks between 2 and 4 breakpoint events for each second of the
    event time (proportionally fewer for a trailing partial second).
    Each event gets a random amplitude and a random slice of its
    second; rounding error in the time slices is carried forward to
    the next second, and the final segment is stretched so the
    envelope spans exactly ``etime`` seconds.
    """
    dests = [0.]
    times = [0.]
    time_error = 0.  # leftover time from the previous second's slices
    loctime = etime
    while loctime > 0:
        # Dead `correct` variable from the original removed: it was
        # computed but never used (numevents already scales by loctime).
        numevents = int((random() * 3. + 2.) * loctime)
        dests += [random() for _ in range(numevents)]
        time_add = [random() / numevents for _ in range(numevents)]
        next_error = 1. - sum(time_add)
        if time_add:
            time_add[0] += time_error
        times += time_add
        time_error = next_error
        loctime -= 1.
    dests[-1] = 0.                       # always end at silence
    times[-1] += etime - sum(times)      # force exact total duration
    times = [1000. * t for t in times]   # seconds -> milliseconds
    curves = [0.] + [window(.75) for _ in range(1, len(dests) - 1)] + [0.]
    return interleave(dests, times, curves)
Example #3
0
def pitch_unidir(time, correct):
    """
    Unidirectional pitch-shift that glides into a specified pitch.

    The glide begins at most 3 semitones away from the target pitch
    and lands on it after ``time`` seconds.
    """
    start = ratio(window(1.) + correct)
    target = ratio(correct)
    dests = [start, target]
    times = [0., time * 1000]
    curves = [1, window(.75)]
    return interleave(dests, times, curves)
Example #4
0
def amp_in(time):
    """
    A double-attack crescendo envelope.

    Ramps to a random intermediate level, then up to full amplitude,
    then decays to silence over the remaining time.
    """
    to_ms = lambda frac: _ms(frac, time)
    rise_one = random() * .35 + .2
    rise_two = random() * .3 + .2
    fall = 1 - rise_one - rise_two
    peak_one = random() * .4 + .4
    levels = [0., peak_one, 1., 0.]
    durations = [0., to_ms(rise_one), to_ms(rise_two), to_ms(fall)]
    shapes = [1.] + window([.75, .75, .75])
    return interleave(levels, durations, shapes)
Example #5
0
File: engine_perc.py  Project: ilzxc/thirst
def amp_env(time):
    """
    A simple AHD (attack-hold-decay) envelope.

    20 ms linear ramp up, hold at a random sustain level (between
    .5 and 1.) for the bulk of the duration, then a 20 ms ramp down.
    """
    ms = lambda t: _ms(t, time)
    attack = decay = 20  # milliseconds
    # NOTE(review): sibling envelopes pass a *fraction* to ms(); passing
    # the full `time` here looks inconsistent -- confirm intent.
    hold = ms(time) - 40  # milliseconds
    sustain = random() * 0.5 + 0.5
    levels = [0.0, sustain, sustain, 0.0]
    durations = [0.0, attack, hold, decay]
    return interleave(levels, durations)
Example #6
0
def pitch_peak(time, correct):
    """
    A triangle-wave-like pitch shift.

    Glides away from the corrected pitch and back again.

    NOTE(review): the original comment claimed the turning point falls
    within the first 75% of the requested time, but the coefficient
    below (`random() * .75 + .2`) allows anywhere from 20% to 95% --
    confirm which is intended.
    """
    to_ms = lambda frac: _ms(frac, time)
    peak = ratio(window(1.) + correct)
    base = ratio(correct)
    rise = random() * .75 + .2
    fall = 1. - rise
    levels = [base, peak, base]
    durations = [0., to_ms(rise), to_ms(fall)]
    shapes = [1.] + window([.75, .75])
    return interleave(levels, durations, shapes)
Example #7
0
def amp_in_out(time):
    """
    An AHD (attack-hold-decay) envelope.

    Attack lands within the first 75 percent of the duration; the
    sustain level (between .5 and 1.) is held for up to 20 percent of
    the duration, then the envelope decays to 0. for the remainder.
    """
    to_ms = lambda frac: _ms(frac, time)
    rise = random() * .55 + .2
    level = random() * .5 + .5
    plateau = random() * .2
    fall = 1. - rise - plateau
    levels = [0., level, level, 0.]
    durations = [0., to_ms(rise), to_ms(plateau), to_ms(fall)]
    shapes = [1., window(.75), 1., window(.75)]
    return interleave(levels, durations, shapes)
Example #8
0
def pitch_cross(time, correct):
    """
    Shifts up then down (or down then up), before returning to the
    original pitch of the sample.
    """
    to_ms = lambda frac: _ms(frac, time)
    segment = lambda: random() * .33 + .1
    shifts = window([1., 1.])
    # Force the two excursions onto opposite sides of the origin.
    if shifts[0] * shifts[1] > 0:
        shifts[one([0, 1])] *= -1
    trans = ratio([s + correct for s in shifts])
    first = segment()
    second = 1 - segment() - first
    decay = 1. - (first + second)
    origin = ratio(correct)
    levels = [origin] + trans + [origin]
    durations = [0., to_ms(first), to_ms(second), to_ms(decay)]
    shapes = [1.] + window([.75, .75, .75])
    return interleave(levels, durations, shapes)
def ifft16_within_rows_preprocess(ymm_wr, ymm_wi, bit_reversal=False):
    """
    Emit the preprocessing instruction stream for a 16-point IFFT whose
    data is laid out within rows of YMM registers.

    ymm_wr / ymm_wi each hold two YMM registers with the real and
    imaginary parts of w0..w7 and w8..w15.  The function generates
    cross-lane permutes, fused multiply-adds and blends (apparently via
    a PeachPy-style assembler DSL -- VPERMPS/VFMADD231PS/VBLENDPS are
    emitted as calls) that rearrange the f-values into the w-values
    described in the comments below, then writes the results back into
    ymm_wr / ymm_wi in place via SWAP.REGISTERS.

    If ``bit_reversal`` is True, a final cross-lane permutation applies
    the (0 2 4 6 1 3 5 7) ordering to all four output registers.
    """
    assert isinstance(ymm_wr, (list, tuple)) and len(ymm_wr) == 2 and all(
        isinstance(reg, YMMRegister) for reg in ymm_wr)
    assert isinstance(ymm_wi, (list, tuple)) and len(ymm_wi) == 2 and all(
        isinstance(reg, YMMRegister) for reg in ymm_wi)

    # w0.re, w1.re,  w2.re,  w3.re,  w4.re,  w5.re,  w6.re,  w7.re = \
    #    = f0.re, f2.re - f3.im, f4.re - f5.im, f6.re - f7.im, f8.re - f9.im, f10.re - f11.im, f12.re - f13.im, f14.re - f15.im
    # w8.re, w9.re, w10.re, w11.re, w12.re, w13.re, w14.re, w15.re = \
    #    = f0.im, f14.re + f15.im, f12.re + f13.im, f10.re + f11.im, f8.re + f9.im, f6.re + f7.im, f4.re + f5.im, f2.re + f3.im
    #
    # w0.im, w1.im,  w2.im,  w3.im,  w4.im,  w5.im,  w6.im,  w7.im = \
    #    = f1.re, f3.re + f2.im, f5.re + f4.im, f7.re + f6.im, r9.re + f8.im, f11.re + f10.im, f13.re + f12.im, f15.re + f14.im
    # w8.im, w9.im, w10.im, w11.im, w12.im, w13.im, w14.im, w15.im = \
    #    = f1.im, f15.re - f14.im, f13.re - f12.im, f11.re - f10.im, f9.re - f8.im, f7.re - f6.im, f5.re - f4.im, f3.re - f2.im

    # Step 1.A:
    #   w0.re, w1.re, w2.re, w3.re, -, w13.re, w14.re, w15.re = \
    #      = f0.re, f2.re - f3.im, f4.re - f5.im, f6.re - f7.im, -, f6.re + f7.im, f4.re + f5.im, f2.re + f3.im
    #   w0.im, w1.im, w2.im, w3.im, -, w13.im, w14.im, w15.im = \
    #      = f1.re, f3.re + f2.im, f5.re + f4.im, f7.re + f6.im, -, f7.re - f6.im, f5.re - f4.im, f3.re - f2.im

    ymm_shuffle_02461642 = YMMRegister()
    VMOVDQA(ymm_shuffle_02461642, Constant.uint32x8(0, 2, 4, 6, 1, 6, 4, 2))
    ymm_shuffle_13570753 = YMMRegister()
    VMOVDQA(ymm_shuffle_13570753, Constant.uint32x8(1, 3, 5, 7, 0, 7, 5, 3))

    ymm_fr_02461642, ymm_fi_13570753 = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_02461642, ymm_shuffle_02461642, ymm_wr[0])
    VPERMPS(ymm_fi_13570753, ymm_shuffle_13570753, ymm_wi[0])
    # Sign mask implements the -f(odd).im terms in the low lanes and
    # the +f(odd).im terms in the high lanes.
    VFMADD231PS(
        ymm_fr_02461642, ymm_fi_13570753,
        Constant.float32x8(0.0, -1.0, -1.0, -1.0, 0.0, +1.0, +1.0, +1.0))

    ymm_fr_13570753, ymm_fi_02461642 = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_13570753, ymm_shuffle_13570753, ymm_wr[0])
    VPERMPS(ymm_fi_02461642, ymm_shuffle_02461642, ymm_wi[0])
    VFMADD231PS(
        ymm_fr_13570753, ymm_fi_02461642,
        Constant.float32x8(0.0, +1.0, +1.0, +1.0, 0.0, -1.0, -1.0, -1.0))

    ymm_wr_0123xDEF, ymm_wi_0123xDEF = ymm_fr_02461642, ymm_fr_13570753

    # Step 1.B:
    #   -, w9.re, w10.re, w11.re, w4.re, w5.re, w6.re, w7.re = \
    #      = -, f14.re + f15.im, f12.re + f13.im, f10.re + f11.im, r8.re - r9.im, r10.re - r11.im, r12.re - r13.im, r14.re - f15.im
    #   -, w9.im, w10.im, w11.im, w4.im, w5.im, w6.im, w7.im = \
    #      = -, f15.re - f14.im, f13.re - f12.im, f11.re - f10.im, r9.re + f8.im, f11.re + f10.im, f13.re + f12.im, f15.re + f14.im

    ymm_shuffle_06420246 = YMMRegister()
    VMOVDQA(ymm_shuffle_06420246, Constant.uint32x8(0, 6, 4, 2, 0, 2, 4, 6))
    ymm_shuffle_17531357 = YMMRegister()
    VMOVDQA(ymm_shuffle_17531357, Constant.uint32x8(1, 7, 5, 3, 1, 3, 5, 7))

    ymm_wr_xxxxCxxx, ymm_wi_xxxxCxxx = YMMRegister(), YMMRegister()
    ymm_wr_0123CDEF, ymm_wi_0123CDEF = YMMRegister(), YMMRegister()

    ymm_fr_8ECA8ACE, ymm_fi_9FDB9BDF = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_8ECA8ACE, ymm_shuffle_06420246, ymm_wr[1])
    VPERMPS(ymm_fi_9FDB9BDF, ymm_shuffle_17531357, ymm_wi[1])
    # Element 12 (w12/"C") needs a plain sum, which the FMA sign mask
    # below cannot produce; compute it separately and blend it in.
    VADDPS(ymm_wr_xxxxCxxx, ymm_fr_8ECA8ACE, ymm_fi_9FDB9BDF)
    VFMADD231PS(
        ymm_fr_8ECA8ACE, ymm_fi_9FDB9BDF,
        Constant.float32x8(0.0, +1.0, +1.0, +1.0, -1.0, -1.0, -1.0, -1.0))
    VBLENDPS(ymm_wr_0123CDEF, ymm_wr_0123xDEF, ymm_wr_xxxxCxxx, 0b00010000)

    ymm_fr_9FDB9BDF, ymm_fi_8ECA8ACE = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_9FDB9BDF, ymm_shuffle_17531357, ymm_wr[1])
    VPERMPS(ymm_fi_8ECA8ACE, ymm_shuffle_06420246, ymm_wi[1])
    VSUBPS(ymm_wi_xxxxCxxx, ymm_fr_9FDB9BDF, ymm_fi_8ECA8ACE)
    VFMADD231PS(
        ymm_fr_9FDB9BDF, ymm_fi_8ECA8ACE,
        Constant.float32x8(0.0, -1.0, -1.0, -1.0, +1.0, +1.0, +1.0, +1.0))
    VBLENDPS(ymm_wi_0123CDEF, ymm_wi_0123xDEF, ymm_wi_xxxxCxxx, 0b00010000)

    ymm_wr_x9AB4567, ymm_wi_x9AB4567 = ymm_fr_8ECA8ACE, ymm_fr_9FDB9BDF
    ymm_wr_89AB4567, ymm_wi_89AB4567 = YMMRegister(), YMMRegister()
    # Lane 0 of w8 is f0.im / f1.im, taken from the Step 1.A permutes.
    VBLENDPS(ymm_wr_89AB4567, ymm_wr_x9AB4567, ymm_fi_02461642, 0b00000001)
    VBLENDPS(ymm_wi_89AB4567, ymm_wi_x9AB4567, ymm_fi_13570753, 0b00000001)

    # Recombine the half-results into the final w0..w7 / w8..w15 order.
    ymm_wr_01234567, ymm_wr_89ABCDEF = YMMRegister(), YMMRegister()
    VBLENDPS(ymm_wr_01234567, ymm_wr_0123CDEF, ymm_wr_89AB4567, 0xF0)
    VBLENDPS(ymm_wr_89ABCDEF, ymm_wr_0123CDEF, ymm_wr_89AB4567, 0x0F)

    ymm_wi_01234567, ymm_wi_89ABCDEF = YMMRegister(), YMMRegister()
    VBLENDPS(ymm_wi_01234567, ymm_wi_0123CDEF, ymm_wi_89AB4567, 0xF0)
    VBLENDPS(ymm_wi_89ABCDEF, ymm_wi_0123CDEF, ymm_wi_89AB4567, 0x0F)

    # Write results back into the caller's registers in place.
    SWAP.REGISTERS(ymm_wr[0], ymm_wr_01234567)
    SWAP.REGISTERS(ymm_wi[0], ymm_wi_01234567)
    SWAP.REGISTERS(ymm_wr[1], ymm_wr_89ABCDEF)
    SWAP.REGISTERS(ymm_wi[1], ymm_wi_89ABCDEF)

    if bit_reversal:
        # Bit reversal
        # w[0] = x0 x8 x4 x12 x2 x10 x6 x14
        # w[1] = x1 x9 x5 x13 x3 x11 x7 x15
        ymm_bit_reversal_mask = YMMRegister()
        VMOVDQA(ymm_bit_reversal_mask,
                Constant.uint32x8(0, 2, 4, 6, 1, 3, 5, 7))
        for ymm in interleave(ymm_wr, ymm_wi):
            VPERMPS(ymm, ymm_bit_reversal_mask, ymm)
Example #10
0
def forward_vfft(reg_t0, reg_t8, reg_t_stride, data_out, reg_row_start=None, reg_row_end=None, ymm_load_mask=None):
    """
    Emit the instruction stream for a forward real-to-complex FFT over
    16 rows of input (vertical FFT across rows, 8 columns at a time).

    reg_t0 / reg_t8 point at rows 0 and 8 of the input; reg_t_stride is
    the byte stride between rows.  Rows outside the half-open range
    [reg_row_start, reg_row_end) (when those registers are given) are
    treated as zero; ymm_load_mask, when given, masks the column loads.
    Results are written into ``data_out`` -- a 16-element list of
    YMMRegister / LocalVariable destinations, alternating real and
    imaginary parts (real at even indices, imaginary at odd).
    """
    assert isinstance(reg_t0, GeneralPurposeRegister64)
    assert isinstance(reg_t8, GeneralPurposeRegister64)
    assert isinstance(reg_t_stride, GeneralPurposeRegister64)
    assert isinstance(data_out, list) and len(data_out) == 16
    assert ymm_load_mask is None or isinstance(ymm_load_mask, YMMRegister)

    out_real, out_imag = data_out[0::2], data_out[1::2]

    # Two of the 16 working slots are spilled to the stack (imag[0],
    # imag[4]); the rest live in registers.
    real, imag = [YMMRegister() for _ in range(8)], [YMMRegister() for _ in range(8)]
    imag[0] = LocalVariable(YMMRegister.size)
    imag[4] = LocalVariable(YMMRegister.size)
    data = interleave(real, imag)

    # Load rows i and i+8 and apply the first (radix-2) butterfly:
    # lo = row[i] + row[i+8], hi = row[i] - row[i+8].
    for i, (data_lo, data_hi) in enumerate(zip(data[0:8], data[8:16])):
        row_lo = i
        row_hi = row_lo + 8

        ymm_data_lo, ymm_data_hi = data_lo, data_hi
        if isinstance(data_lo, LocalVariable):
            ymm_data_lo = YMMRegister()
        if isinstance(data_hi, LocalVariable):
            ymm_data_hi = YMMRegister()

        # Zero-fill first so skipped (out-of-range) rows contribute 0.
        VXORPS(ymm_data_lo, ymm_data_lo, ymm_data_lo)
        skip_data_lo = Label()
        if reg_row_start:
            CMP(reg_row_start, row_lo)
            JA(skip_data_lo)
        if reg_row_end:
            CMP(reg_row_end, row_lo)
            JBE(skip_data_lo)
        if ymm_load_mask is None:
            VMOVUPS(ymm_data_lo, [reg_t0])
        else:
            VMASKMOVPS(ymm_data_lo, ymm_load_mask, [reg_t0])
        if i + 1 != 8:
            ADD(reg_t0, reg_t_stride)
        LABEL(skip_data_lo)

        VMOVAPS(ymm_data_hi, ymm_data_lo)
        skip_data_hi = Label()
        if reg_row_start:
            CMP(reg_row_start, row_hi)
            JA(skip_data_hi)
        if reg_row_end:
            CMP(reg_row_end, row_hi)
            JBE(skip_data_hi)
        if ymm_load_mask is None:
            VMOVUPS(ymm_data_hi, [reg_t8])
            butterfly(ymm_data_lo, ymm_data_hi)
        else:
            ymm_temp_hi = YMMRegister()
            VMASKMOVPS(ymm_temp_hi, ymm_load_mask, [reg_t8])
            VSUBPS(ymm_data_hi, ymm_data_lo, ymm_temp_hi)
            VADDPS(ymm_data_lo, ymm_data_lo, ymm_temp_hi)
        if i + 1 != 8:
            ADD(reg_t8, reg_t_stride)
        LABEL(skip_data_hi)

        if isinstance(data_lo, LocalVariable):
            VMOVAPS(data_lo, ymm_data_lo)
        if isinstance(data_hi, LocalVariable):
            VMOVAPS(data_hi, ymm_data_hi)

    # FFT8: multiplication by twiddle factors
    fft4_scale_b, fft4_negate_b = {}, {}
    fft2_scale_b, fft2_negate_b = {}, {}

    # w6.re, w6.im = w6.im, -w6.re
    SWAP.REGISTERS(real[6], imag[6])
    fft4_negate_b[id(imag[6])] = True

    # w5.re, w5.im =  SQRT2_OVER_2 * (w5.re + w5.im),  SQRT2_OVER_2 * (w5.im - w5.re)
    butterfly(imag[5], real[5])
    SWAP.REGISTERS(real[5], imag[5])

    # w7.re, w7.im = -SQRT2_OVER_2 * (w7.re - w7.im), -SQRT2_OVER_2 * (w7.re + w7.im)
    butterfly(real[7], imag[7], negate_b=True)
    fft4_negate_b[id(real[7])] = True
    fft4_negate_b[id(imag[7])] = True

    # Propagate multiplication by sqrt2_over_2 until the last butterfly in FFT2
    ymm_sqrt2_over_2 = YMMRegister()
    fft2_scale_b[id(real[5])] = ymm_sqrt2_over_2
    fft2_scale_b[id(imag[5])] = ymm_sqrt2_over_2
    fft2_scale_b[id(real[7])] = ymm_sqrt2_over_2
    fft2_scale_b[id(imag[7])] = ymm_sqrt2_over_2

    # 2x FFT4: butterfly
    for data_lo, data_hi in zip(data[0:4] + data[8:12], data[4:8] + data[12:16]):
        butterfly(data_lo, data_hi, negate_b=fft4_negate_b.get(id(data_hi), False), scale_b=fft4_scale_b.get(id(data_hi)))

    # 2x FFT4: multiplication by twiddle factors

    # w3.re, w3.im = w3.im, -w3.re
    # w7.re, w7.im = w7.im, -w7.re
    SWAP.REGISTERS(real[3], imag[3])
    SWAP.REGISTERS(real[7], imag[7])
    fft2_negate_b[id(imag[3])] = True
    fft2_negate_b[id(imag[7])] = True

    # 4x FFT2: butterfly

    # Process the first two elements separately
    ymm_real0, ymm_real1 = butterfly(real[0], real[1], writeback=False)
    store_ymm_result(out_real[4], ymm_real1) # bit-reversal: 1->4
    ymm_imag0, ymm_imag1 = butterfly(imag[0], imag[1], negate_out_b=True, writeback=False)
    store_ymm_result(out_imag[4], ymm_imag1) # bit-reversal: 1->4

    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))
    for i, (data_lo, data_hi) in enumerate(zip(data[4:6] + data[8:10] + data[12:14], data[6:8] + data[10:12] + data[14:16])):
        butterfly(data_lo, data_hi,
            negate_b=fft2_negate_b.get(id(data_hi), False), scale_b=fft2_scale_b.get(id(data_hi)))

    # Pack the two purely-real spectra (DC and Nyquist) together.
    butterfly(ymm_real0, ymm_imag0)
    store_ymm_result(out_real[0], ymm_real0)
    store_ymm_result(out_imag[0], ymm_imag0)

    # Bit reversal
    for i in range(8):
        new_i = fft8_bitreverse(i)
        if new_i > i:
            real[i], real[new_i] = real[new_i], real[i]
            imag[i], imag[new_i] = imag[new_i], imag[i]
    data = interleave(real, imag)

    # Split w[n] into even (G) and odd (H) parts:
    # 2*G[n] = w[n] + conj(w[8-n]), 2*H[n] = w[n] - conj(w[8-n]).
    ymm_two_g2_real, ymm_two_g2_imag = YMMRegister(), YMMRegister()
    ymm_two_h2_real, ymm_two_h2_imag = YMMRegister(), YMMRegister()

    VADDPS(ymm_two_g2_real, real[2], real[6])
    VSUBPS(ymm_two_h2_imag, real[6], real[2])

    VSUBPS(ymm_two_g2_imag, imag[2], imag[6])
    VADDPS(ymm_two_h2_real, imag[2], imag[6])

    ymm_two_g1_real, ymm_two_g1_imag = YMMRegister(), YMMRegister()
    ymm_two_h1_real, ymm_two_h1_imag = YMMRegister(), YMMRegister()
    ymm_real1 = load_ymm_variable(real[1])
    VADDPS(ymm_two_g1_real, ymm_real1, real[7])
    VSUBPS(ymm_two_h1_imag, real[7], ymm_real1)

    ymm_imag1 = load_ymm_variable(imag[1])
    VSUBPS(ymm_two_g1_imag, ymm_imag1, imag[7])
    VADDPS(ymm_two_h1_real, ymm_imag1, imag[7])

    ymm_two_h2_add, ymm_two_h2_sub = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_h2_add, ymm_two_h2_real, ymm_two_h2_imag)
    VSUBPS(ymm_two_h2_sub, ymm_two_h2_imag, ymm_two_h2_real)

    ymm_two_g3_real, ymm_two_g3_imag = YMMRegister(), YMMRegister()
    ymm_two_h3_real, ymm_two_h3_imag = YMMRegister(), YMMRegister()
    VADDPS(ymm_two_g3_real, real[3], real[5])
    VSUBPS(ymm_two_h3_imag, real[5], real[3])

    VSUBPS(ymm_two_g3_imag, imag[3], imag[5])
    VADDPS(ymm_two_h3_real, imag[3], imag[5])

    # const float two_w2_real = two_g2_real + SQRT2_OVER_2 * (two_h2_real + two_h2_imag);
    # const float two_w2_imag = two_g2_imag + SQRT2_OVER_2 * (two_h2_imag - two_h2_real);
    # const float two_w6_real = two_g2_real - SQRT2_OVER_2 * (two_h2_real + two_h2_imag);
    # const float two_w6_imag = -two_g2_imag + SQRT2_OVER_2 * (two_h2_imag - two_h2_real);

    ymm_sqrt2_over_2 = YMMRegister()
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))

    ymm_two_w2_real, ymm_two_w6_real = YMMRegister(), ymm_two_g2_real
    VMOVAPS(ymm_two_w2_real, ymm_two_g2_real)
    VFMADD231PS(ymm_two_w2_real, ymm_two_h2_add, ymm_sqrt2_over_2)
    VFNMADD231PS(ymm_two_w6_real, ymm_two_h2_add, ymm_sqrt2_over_2)

    ymm_two_w2_imag, ymm_two_w6_imag = YMMRegister(), ymm_two_g2_imag
    VMOVAPS(ymm_two_w2_imag, ymm_two_g2_imag)
    VFMADD231PS(ymm_two_w2_imag, ymm_two_h2_sub, ymm_sqrt2_over_2)
    VFMSUB231PS(ymm_two_w6_imag, ymm_two_h2_sub, ymm_sqrt2_over_2)

    # The G/H recombination computed 2*w[n]; halve before storing.
    ymm_half = YMMRegister()
    VMOVAPS(ymm_half, Constant.float32x8(0.5))

    VMULPS(ymm_two_w2_real, ymm_two_w2_real, ymm_half)
    store_ymm_result(out_real[2], ymm_two_w2_real)
    VMULPS(ymm_two_w6_real, ymm_two_w6_real, ymm_half)
    store_ymm_result(out_real[6], ymm_two_w6_real)

    VMULPS(ymm_two_w2_imag, ymm_two_w2_imag, ymm_half)
    store_ymm_result(out_imag[2], ymm_two_w2_imag)
    VMULPS(ymm_two_w6_imag, ymm_two_w6_imag, ymm_half)
    store_ymm_result(out_imag[6], ymm_two_w6_imag)

    # const float two_w1_real = two_g1_real + two_h1_real * COS_1PI_OVER_8 + two_h1_imag * COS_3PI_OVER_8;
    # const float two_w1_imag = two_g1_imag + two_h1_imag * COS_1PI_OVER_8 - two_h1_real * COS_3PI_OVER_8;
    # const float two_w7_real = two_g1_real - two_h1_real * COS_1PI_OVER_8 - two_h1_imag * COS_3PI_OVER_8;
    # const float two_w7_imag = -two_g1_imag + two_h1_imag * COS_1PI_OVER_8 - two_h1_real * COS_3PI_OVER_8;

    # const float two_w3_real = two_g3_real + two_h3_real * COS_3PI_OVER_8 + two_h3_imag * COS_1PI_OVER_8;
    # const float two_w3_imag = two_g3_imag + two_h3_imag * COS_3PI_OVER_8 - two_h3_real * COS_1PI_OVER_8;
    # const float two_w5_real = two_g3_real - two_h3_real * COS_3PI_OVER_8 - two_h3_imag * COS_1PI_OVER_8;
    # const float two_w5_imag = -two_g3_imag + two_h3_imag * COS_3PI_OVER_8 - two_h3_real * COS_1PI_OVER_8;

    ymm_cos_1pi_over_8 = YMMRegister()
    VMOVAPS(ymm_cos_1pi_over_8, Constant.float32x8(cos_npi_over_8[1]))

    ymm_two_w1_real, ymm_two_w7_real = YMMRegister(), ymm_two_g1_real
    VMOVAPS(ymm_two_w1_real, ymm_two_g1_real)
    VFMADD231PS(ymm_two_w1_real, ymm_two_h1_real, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_two_w7_real, ymm_two_h1_real, ymm_cos_1pi_over_8)

    ymm_two_w1_imag, ymm_two_w7_imag = YMMRegister(), ymm_two_g1_imag
    VMOVAPS(ymm_two_w1_imag, ymm_two_g1_imag)
    VFMADD231PS(ymm_two_w1_imag, ymm_two_h1_imag, ymm_cos_1pi_over_8)
    VFMSUB231PS(ymm_two_w7_imag, ymm_two_h1_imag, ymm_cos_1pi_over_8)

    ymm_two_w3_real, ymm_two_w5_real = YMMRegister(), ymm_two_g3_real
    VMOVAPS(ymm_two_w3_real, ymm_two_g3_real)
    VFMADD231PS(ymm_two_w3_real, ymm_two_h3_imag, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_two_w5_real, ymm_two_h3_imag, ymm_cos_1pi_over_8)

    ymm_two_w3_imag, ymm_two_w5_imag = YMMRegister(), ymm_two_g3_imag
    VMOVAPS(ymm_two_w3_imag, ymm_two_g3_imag)
    VFNMADD231PS(ymm_two_w3_imag, ymm_two_h3_real, ymm_cos_1pi_over_8)
    VFNMSUB231PS(ymm_two_w5_imag, ymm_two_h3_real, ymm_cos_1pi_over_8)

    ymm_cos_3pi_over_8 = YMMRegister()
    VMOVAPS(ymm_cos_3pi_over_8, Constant.float32x8(cos_npi_over_8[3]))

    VFMADD231PS(ymm_two_w1_real, ymm_two_h1_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w7_real, ymm_two_h1_imag, ymm_cos_3pi_over_8)

    VFNMADD231PS(ymm_two_w1_imag, ymm_two_h1_real, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w7_imag, ymm_two_h1_real, ymm_cos_3pi_over_8)

    VFMADD231PS(ymm_two_w3_real, ymm_two_h3_real, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_two_w5_real, ymm_two_h3_real, ymm_cos_3pi_over_8)

    VFMADD231PS(ymm_two_w3_imag, ymm_two_h3_imag, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_two_w5_imag, ymm_two_h3_imag, ymm_cos_3pi_over_8)

    ymm_half = YMMRegister()
    VMOVAPS(ymm_half, Constant.float32x8(0.5))

    VMULPS(ymm_two_w1_real, ymm_two_w1_real, ymm_half)
    store_ymm_result(out_real[1], ymm_two_w1_real)
    VMULPS(ymm_two_w7_real, ymm_two_w7_real, ymm_half)
    store_ymm_result(out_real[7], ymm_two_w7_real)

    VMULPS(ymm_two_w1_imag, ymm_two_w1_imag, ymm_half)
    store_ymm_result(out_imag[1], ymm_two_w1_imag)
    VMULPS(ymm_two_w7_imag, ymm_two_w7_imag, ymm_half)
    store_ymm_result(out_imag[7], ymm_two_w7_imag)

    VMULPS(ymm_two_w3_real, ymm_two_w3_real, ymm_half)
    store_ymm_result(out_real[3], ymm_two_w3_real)
    VMULPS(ymm_two_w5_real, ymm_two_w5_real, ymm_half)
    store_ymm_result(out_real[5], ymm_two_w5_real)

    VMULPS(ymm_two_w3_imag, ymm_two_w3_imag, ymm_half)
    store_ymm_result(out_imag[3], ymm_two_w3_imag)
    VMULPS(ymm_two_w5_imag, ymm_two_w5_imag, ymm_half)
    store_ymm_result(out_imag[5], ymm_two_w5_imag)
Example #11
0
def inverse_vfft(reg_t0, reg_t8, reg_t_stride, data_in, reg_row_start=None, reg_row_end=None, store_mask=None):
    """
    Emit the instruction stream for an inverse complex-to-real FFT over
    16 rows of output (vertical FFT across rows, 8 columns at a time).

    ``data_in`` is a 16-element list of inputs alternating real and
    imaginary parts (real at even indices, imaginary at odd).
    reg_t0 / reg_t8 point at output rows 0 and 8; reg_t_stride is the
    byte stride between rows.  Rows outside [reg_row_start, reg_row_end)
    (when those registers are given) are not stored; store_mask, when
    given, is a stacked YMM mask applied to the column stores.
    Inputs are scaled by 1/16 (0.0625) as part of the inverse transform.
    """
    assert isinstance(reg_t0, GeneralPurposeRegister64)
    assert isinstance(reg_t8, GeneralPurposeRegister64)
    assert isinstance(reg_t_stride, GeneralPurposeRegister64)
    assert isinstance(data_in, list) and len(data_in) == 16
    assert reg_row_end is None or isinstance(reg_row_end, GeneralPurposeRegister32)
    assert store_mask is None or isinstance(store_mask, LocalVariable) and store_mask.size == YMMRegister.size

    in_real, in_imag = data_in[0::2], data_in[1::2]

    ymm_scale_factor = YMMRegister()
    VMOVAPS(ymm_scale_factor, Constant.float32x8(0.0625))

    ymm_W1_real, ymm_W1_imag = YMMRegister(), YMMRegister()
    VMULPS(ymm_W1_real, ymm_scale_factor, in_real[1])
    VMULPS(ymm_W1_imag, ymm_scale_factor, in_imag[1])

    ymm_W2_real, ymm_W2_imag = YMMRegister(), YMMRegister()
    VMULPS(ymm_W2_real, ymm_scale_factor, in_real[2])
    VMULPS(ymm_W2_imag, ymm_scale_factor, in_imag[2])

    ymm_W3_real, ymm_W3_imag = YMMRegister(), YMMRegister()
    VMULPS(ymm_W3_real, ymm_scale_factor, in_real[3])
    VMULPS(ymm_W3_imag, ymm_scale_factor, in_imag[3])

    # G[n].real, H[n].real = W[n].real + W[8-n].real, W[n].real - W[8-n].real
    # G[n].imag, H[n].imag = W[n].imag - W[8-n].imag, W[n].imag + W[8-n].imag
    ymm_W7_real, ymm_W7_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_W7_real, in_real[7])
    ymm_G1_real, ymm_H1_real = butterfly(ymm_W1_real, ymm_W7_real, scale_b=ymm_scale_factor)
    VMOVUPS(ymm_W7_imag, in_imag[7])
    ymm_G1_imag, ymm_H1_imag = butterfly(ymm_W1_imag, ymm_W7_imag, scale_b=ymm_scale_factor, negate_b=True)

    ymm_W6_real, ymm_W6_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_W6_real, in_real[6])
    ymm_G2_real, ymm_H2_real = butterfly(ymm_W2_real, ymm_W6_real, scale_b=ymm_scale_factor)
    VMOVUPS(ymm_W6_imag, in_imag[6])
    ymm_G2_imag, ymm_H2_imag = butterfly(ymm_W2_imag, ymm_W6_imag, scale_b=ymm_scale_factor, negate_b=True)

    ymm_W5_real, ymm_W5_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_W5_real, in_real[5])
    ymm_G3_real, ymm_H3_real = butterfly(ymm_W3_real, ymm_W5_real, scale_b=ymm_scale_factor)
    VMOVUPS(ymm_W5_imag, in_imag[5])
    ymm_G3_imag, ymm_H3_imag = butterfly(ymm_W3_imag, ymm_W5_imag, scale_b=ymm_scale_factor, negate_b=True)

    # H[2]+, H[2]- = H[2].real + H[2].imag, H[2].real - H[2].imag
    ymm_H2_add, ymm_H2_sub = butterfly(ymm_H2_real, ymm_H2_imag)

    # w[   n].real =  G[  n].real - H[  n].real * cos((N-n)*pi/2N) - H[  n].imag * cos(n*pi/2N)
    # w[2N-n].real =  G[  n].real + H[  n].real * cos((N-n)*pi/2N) + H[  n].imag * cos(n*pi/2N)
    # w[   n].imag =  G[  n].imag + H[  n].real * cos(n*pi/2N)     - H[  n].imag * cos((N-n)*pi/2N)
    # w[2N-n].imag = -G[  n].imag + H[  n].real * cos(n*pi/2N)     - H[  n].imag * cos((N-n)*pi/2N)
    # w[ N-n].real =  G[N-n].real - H[N-n].real * cos(n*pi/2N)     - H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N+n].real =  G[N-n].real + H[N-n].real * cos(n*pi/2N)     + H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N-n].imag =  G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)
    # w[ N+n].imag = -G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)

    ymm_cos_1pi_over_8, ymm_cos_3pi_over_8 = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_cos_3pi_over_8, Constant.float32x8(cos_npi_over_8[3]))
    VMOVAPS(ymm_cos_1pi_over_8, Constant.float32x8(cos_npi_over_8[1]))

    ymm_w1_real, ymm_w7_real = YMMRegister(), ymm_G1_real
    VMOVAPS(ymm_w1_real, ymm_G1_real)
    VFNMADD231PS(ymm_w1_real, ymm_H1_real, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_w7_real, ymm_H1_real, ymm_cos_3pi_over_8)

    ymm_w1_imag, ymm_w7_imag = YMMRegister(), ymm_G1_imag
    VMOVAPS(ymm_w1_imag, ymm_G1_imag)
    VFMADD231PS(ymm_w1_imag, ymm_H1_real, ymm_cos_1pi_over_8)
    VFMSUB231PS(ymm_w7_imag, ymm_H1_real, ymm_cos_1pi_over_8)

    ymm_w3_real, ymm_w5_real = YMMRegister(), ymm_G3_real
    VMOVAPS(ymm_w3_real, ymm_G3_real)
    VFNMADD231PS(ymm_w3_real, ymm_H3_real, ymm_cos_1pi_over_8)
    VFMADD231PS(ymm_w5_real, ymm_H3_real, ymm_cos_1pi_over_8)

    ymm_w3_imag, ymm_w5_imag = YMMRegister(), ymm_G3_imag
    VMOVAPS(ymm_w3_imag, ymm_G3_imag)
    VFMADD231PS(ymm_w3_imag, ymm_H3_real, ymm_cos_3pi_over_8)
    VFMSUB231PS(ymm_w5_imag, ymm_H3_real, ymm_cos_3pi_over_8)

    ymm_sqrt2_over_2 = YMMRegister()
    VMOVAPS(ymm_sqrt2_over_2, Constant.float32x8(sqrt2_over_2))

    # w[ N/2].real =  G[N/2].real - H[N/2]+ * sqrt(2)/2
    # w[ N/2].imag =  G[N/2].imag + H[N/2]- * sqrt(2)/2
    # w[3N/2].real =  G[N/2].real + H[N/2]+ * sqrt(2)/2
    # w[3N/2].imag = -G[N/2].imag + H[N/2]- * sqrt(2)/2
    ymm_w2_real, ymm_w6_real = YMMRegister(), ymm_G2_real
    VMOVAPS(ymm_w2_real, ymm_G2_real)
    VFNMADD231PS(ymm_w2_real, ymm_H2_add, ymm_sqrt2_over_2)
    VFMADD231PS(ymm_w6_real, ymm_H2_add, ymm_sqrt2_over_2)

    ymm_w2_imag, ymm_w6_imag = YMMRegister(), ymm_G2_imag
    VMOVAPS(ymm_w2_imag, ymm_G2_imag)
    VFMADD231PS(ymm_w2_imag, ymm_H2_sub, ymm_sqrt2_over_2)
    VFMSUB231PS(ymm_w6_imag, ymm_H2_sub, ymm_sqrt2_over_2)

    # w[   n].real =  G[  n].real - H[  n].real * cos((N-n)*pi/2N) - H[  n].imag * cos(n*pi/2N)
    # w[2N-n].real =  G[  n].real + H[  n].real * cos((N-n)*pi/2N) + H[  n].imag * cos(n*pi/2N)
    # w[   n].imag =  G[  n].imag + H[  n].real * cos(n*pi/2N)     - H[  n].imag * cos((N-n)*pi/2N)
    # w[2N-n].imag = -G[  n].imag + H[  n].real * cos(n*pi/2N)     - H[  n].imag * cos((N-n)*pi/2N)
    # w[ N-n].real =  G[N-n].real - H[N-n].real * cos(n*pi/2N)     - H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N+n].real =  G[N-n].real + H[N-n].real * cos(n*pi/2N)     + H[N-n].imag * cos((N-n)*pi/2N)
    # w[ N-n].imag =  G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)
    # w[ N+n].imag = -G[N-n].imag + H[N-n].real * cos((N-n)*pi/2N) - H[N-n].imag * cos(n*pi/2N)

    ymm_cos_1pi_over_8, ymm_cos_3pi_over_8 = YMMRegister(), YMMRegister()
    VMOVAPS(ymm_cos_1pi_over_8, Constant.float32x8(cos_npi_over_8[1]))
    VMOVAPS(ymm_cos_3pi_over_8, Constant.float32x8(cos_npi_over_8[3]))

    VFNMADD231PS(ymm_w1_real, ymm_H1_imag, ymm_cos_1pi_over_8)
    VFMADD231PS(ymm_w7_real, ymm_H1_imag, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_w1_imag, ymm_H1_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_w7_imag, ymm_H1_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_w3_real, ymm_H3_imag, ymm_cos_3pi_over_8)
    VFMADD231PS(ymm_w5_real, ymm_H3_imag, ymm_cos_3pi_over_8)
    VFNMADD231PS(ymm_w3_imag, ymm_H3_imag, ymm_cos_1pi_over_8)
    VFNMADD231PS(ymm_w5_imag, ymm_H3_imag, ymm_cos_1pi_over_8)

    # Working set: w0, w4.re, w4.im spill to the stack; the rest stay
    # in the registers computed above.
    data = [
        LocalVariable(YMMRegister.size), YMMRegister(),
        ymm_w1_real, ymm_w1_imag,
        ymm_w2_real, ymm_w2_imag,
        ymm_w3_real, ymm_w3_imag,
        LocalVariable(YMMRegister.size), LocalVariable(YMMRegister.size),
        ymm_w5_real, ymm_w5_imag,
        ymm_w6_real, ymm_w6_imag,
        ymm_w7_real, ymm_w7_imag
    ]
    real, imag = data[0::2], data[1::2]

    # TODO: optimize
    ymm_w0_real, ymm_w0_imag = YMMRegister(), imag[0]
    VMOVUPS(ymm_w0_real, in_real[0])
    VMOVUPS(ymm_w0_imag, in_imag[0])
    VMULPS(ymm_w0_real, ymm_w0_real, Constant.float32x8(0.0625))
    butterfly(ymm_w0_real, ymm_w0_imag, scale_b=Constant.float32x8(0.0625))
    VMOVAPS(real[0], ymm_w0_real)

    # TODO: optimize
    ymm_w4_real, ymm_w4_imag = YMMRegister(), YMMRegister()
    VMOVUPS(ymm_w4_real, in_real[4])
    VMOVUPS(ymm_w4_imag, in_imag[4])
    VMULPS(ymm_w4_real, ymm_w4_real, Constant.float32x8(0.125))
    VMULPS(ymm_w4_imag, ymm_w4_imag, Constant.float32x8(-0.125))
    VMOVAPS(real[4], ymm_w4_real)
    VMOVAPS(imag[4], ymm_w4_imag)

    # Bit reversal
    for i in range(8):
        new_i = fft8_bitreverse(i)
        if new_i > i:
            real[i], real[new_i] = real[new_i], real[i]
            imag[i], imag[new_i] = imag[new_i], imag[i]
    data = interleave(real, imag)

    # 4x FFT2: butterfly
    for i, (data_lo, data_hi) in enumerate(zip(data[0:2] + data[4:6] + data[8:10] + data[12:14], data[2:4] + data[6:8] + data[10:12] + data[14:16])):
        butterfly(data_lo, data_hi)

    # 2x FFT4: multiplication by twiddle factors
    fft4_scale_b, fft4_negate_b = {}, {}
    fft8_scale_b, fft8_negate_b = {}, {}

    # w3.re, w3.im = -w3.im, w3.re
    # w7.re, w7.im = -w7.im, w7.re
    SWAP.REGISTERS(real[3], imag[3])
    fft4_negate_b[id(real[3])] = True
    SWAP.REGISTERS(real[7], imag[7])
    fft4_negate_b[id(real[7])] = True

    # 2x FFT4: butterfly
    for data_lo, data_hi in zip(data[0:4] + data[8:12], data[4:8] + data[12:16]):
        butterfly(data_lo, data_hi, negate_b=fft4_negate_b.get(id(data_hi), False))

    # FFT8: multiplication by twiddle factors

    # w6.re, w6.im = -w6.im, w6.re
    SWAP.REGISTERS(real[6], imag[6])
    fft8_negate_b[id(real[6])] = True

    # w5.re, w5.im =  SQRT2_OVER_2 * (w5.re - w5.im), SQRT2_OVER_2 * (w5.re + w5.im)
    butterfly(real[5], imag[5], negate_b=True)
    fft8_scale_b[id(real[5])] = Constant.float32x8(sqrt2_over_2)
    fft8_scale_b[id(imag[5])] = Constant.float32x8(sqrt2_over_2)

    # w7.re, w7.im = -SQRT2_OVER_2 * (w7.re + w7.im), SQRT2_OVER_2 * (w7.re - w7.im)
    butterfly(real[7], imag[7])
    fft8_scale_b[id(real[7])] = Constant.float32x8(sqrt2_over_2)
    fft8_negate_b[id(real[7])] = True
    fft8_scale_b[id(imag[7])] = Constant.float32x8(sqrt2_over_2)

    ymm_store_mask = YMMRegister()
    if store_mask:
        VMOVAPS(ymm_store_mask, store_mask)

    # FFT8: butterfly, fused with the row stores (rows outside the
    # requested range are skipped).
    with Block() as store_data:
        for i, (data_lo, data_hi) in enumerate(zip(data[0:8], data[8:16])):
            row_lo = i
            row_hi = row_lo + 8

            ymm_data_lo, ymm_data_hi = \
                butterfly(data_lo, data_hi,
                    scale_b=fft8_scale_b.get(id(data_hi)),
                    negate_b=fft8_negate_b.get(id(data_hi), False),
                    writeback=False)

            with Block() as store_data_lo:
                if reg_row_start:
                    CMP(reg_row_start, row_lo)
                    JA(store_data_lo.end)
                    if reg_row_end:
                        CMP(reg_row_end, row_lo)
                        JBE(store_data_lo.end)
                elif reg_row_end:
                    CMP(reg_row_end, row_lo)
                    JBE(store_data.end)
                if store_mask:
                    VMASKMOVPS([reg_t0], ymm_store_mask, ymm_data_lo)
                else:
                    VMOVUPS([reg_t0], ymm_data_lo)
                if i + 1 != 8:
                    ADD(reg_t0, reg_t_stride)

            with Block() as store_data_hi:
                if reg_row_start:
                    CMP(reg_row_start, row_hi)
                    JA(store_data_hi.end)
                if reg_row_end:
                    CMP(reg_row_end, row_hi)
                    JBE(store_data_hi.end)
                if store_mask:
                    VMASKMOVPS([reg_t8], ymm_store_mask, ymm_data_hi)
                else:
                    VMOVUPS([reg_t8], ymm_data_hi)
                if i + 1 != 8:
                    ADD(reg_t8, reg_t_stride)
def ifft16_within_rows_preprocess(ymm_wr, ymm_wi, bit_reversal=False):
    """Emit the preprocessing stage of a 16-point inverse FFT computed within AVX rows.

    Takes the packed frequency-domain data split across two YMM register pairs
    (``ymm_wr`` holds real parts, ``ymm_wi`` imaginary parts; index 0 covers
    elements 0-7, index 1 covers elements 8-15) and, using cross-lane permutes
    (VPERMPS), sign-folding fused multiply-adds (VFMADD231PS) and blends,
    rearranges them into the w0..w15 layout documented in the comments below.
    Results are written back into the caller's registers via SWAP.REGISTERS.

    This generates PeachPy instructions at assembly-construction time; it does
    not compute anything itself.

    :param ymm_wr: pair (list/tuple of 2) of YMMRegister with real parts.
    :param ymm_wi: pair (list/tuple of 2) of YMMRegister with imaginary parts.
    :param bit_reversal: if True, additionally permute each output register
        into bit-reversed element order for the subsequent radix-2 stages.
    """
    assert isinstance(ymm_wr, (list, tuple)) and len(ymm_wr) == 2 and all(isinstance(reg, YMMRegister) for reg in ymm_wr)
    assert isinstance(ymm_wi, (list, tuple)) and len(ymm_wi) == 2 and all(isinstance(reg, YMMRegister) for reg in ymm_wi)

    # Target layout (f* = input elements, w* = outputs):
    # w0.re, w1.re,  w2.re,  w3.re,  w4.re,  w5.re,  w6.re,  w7.re = \
    #    = f0.re, f2.re - f3.im, f4.re - f5.im, f6.re - f7.im, f8.re - f9.im, f10.re - f11.im, f12.re - f13.im, f14.re - f15.im
    # w8.re, w9.re, w10.re, w11.re, w12.re, w13.re, w14.re, w15.re = \
    #    = f0.im, f14.re + f15.im, f12.re + f13.im, f10.re + f11.im, f8.re + f9.im, f6.re + f7.im, f4.re + f5.im, f2.re + f3.im
    #
    # w0.im, w1.im,  w2.im,  w3.im,  w4.im,  w5.im,  w6.im,  w7.im = \
    #    = f1.re, f3.re + f2.im, f5.re + f4.im, f7.re + f6.im, r9.re + f8.im, f11.re + f10.im, f13.re + f12.im, f15.re + f14.im
    # w8.im, w9.im, w10.im, w11.im, w12.im, w13.im, w14.im, w15.im = \
    #    = f1.im, f15.re - f14.im, f13.re - f12.im, f11.re - f10.im, f9.re - f8.im, f7.re - f6.im, f5.re - f4.im, f3.re - f2.im

    # Step 1.A:
    #   w0.re, w1.re, w2.re, w3.re, -, w13.re, w14.re, w15.re = \
    #      = f0.re, f2.re - f3.im, f4.re - f5.im, f6.re - f7.im, -, f6.re + f7.im, f4.re + f5.im, f2.re + f3.im
    #   w0.im, w1.im, w2.im, w3.im, -, w13.im, w14.im, w15.im = \
    #      = f1.re, f3.re + f2.im, f5.re + f4.im, f7.re + f6.im, -, f7.re - f6.im, f5.re - f4.im, f3.re - f2.im

    # Register-name suffixes encode the element order after the permute,
    # e.g. _02461642 means lanes hold elements 0,2,4,6,1,6,4,2.
    ymm_shuffle_02461642 = YMMRegister()
    VMOVDQA(ymm_shuffle_02461642, Constant.uint32x8(0, 2, 4, 6, 1, 6, 4, 2))
    ymm_shuffle_13570753 = YMMRegister()
    VMOVDQA(ymm_shuffle_13570753, Constant.uint32x8(1, 3, 5, 7, 0, 7, 5, 3))

    # The +/-1.0 constant gives per-lane add or subtract in one FMA.
    ymm_fr_02461642, ymm_fi_13570753 = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_02461642, ymm_shuffle_02461642, ymm_wr[0])
    VPERMPS(ymm_fi_13570753, ymm_shuffle_13570753, ymm_wi[0])
    VFMADD231PS(ymm_fr_02461642, ymm_fi_13570753, Constant.float32x8(0.0, -1.0, -1.0, -1.0, 0.0, +1.0, +1.0, +1.0))

    ymm_fr_13570753, ymm_fi_02461642 = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_13570753, ymm_shuffle_13570753, ymm_wr[0])
    VPERMPS(ymm_fi_02461642, ymm_shuffle_02461642, ymm_wi[0])
    VFMADD231PS(ymm_fr_13570753, ymm_fi_02461642, Constant.float32x8(0.0, +1.0, +1.0, +1.0, 0.0, -1.0, -1.0, -1.0))

    ymm_wr_0123xDEF, ymm_wi_0123xDEF = ymm_fr_02461642, ymm_fr_13570753

    # Step 1.B:
    #   -, w9.re, w10.re, w11.re, w4.re, w5.re, w6.re, w7.re = \
    #      = -, f14.re + f15.im, f12.re + f13.im, f10.re + f11.im, r8.re - r9.im, r10.re - r11.im, r12.re - r13.im, r14.re - f15.im
    #   -, w9.im, w10.im, w11.im, w4.im, w5.im, w6.im, w7.im = \
    #      = -, f15.re - f14.im, f13.re - f12.im, f11.re - f10.im, r9.re + f8.im, f11.re + f10.im, f13.re + f12.im, f15.re + f14.im

    ymm_shuffle_06420246 = YMMRegister()
    VMOVDQA(ymm_shuffle_06420246, Constant.uint32x8(0, 6, 4, 2, 0, 2, 4, 6))
    ymm_shuffle_17531357 = YMMRegister()
    VMOVDQA(ymm_shuffle_17531357, Constant.uint32x8(1, 7, 5, 3, 1, 3, 5, 7))

    ymm_wr_xxxxCxxx, ymm_wi_xxxxCxxx = YMMRegister(), YMMRegister()
    ymm_wr_0123CDEF, ymm_wi_0123CDEF = YMMRegister(), YMMRegister()

    # Element 12 (C) needs the opposite sign of its lane neighbours, so it is
    # computed separately (VADDPS/VSUBPS) and blended in at lane 4.
    ymm_fr_8ECA8ACE, ymm_fi_9FDB9BDF = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_8ECA8ACE, ymm_shuffle_06420246, ymm_wr[1])
    VPERMPS(ymm_fi_9FDB9BDF, ymm_shuffle_17531357, ymm_wi[1])
    VADDPS(ymm_wr_xxxxCxxx, ymm_fr_8ECA8ACE, ymm_fi_9FDB9BDF)
    VFMADD231PS(ymm_fr_8ECA8ACE, ymm_fi_9FDB9BDF, Constant.float32x8(0.0, +1.0, +1.0, +1.0, -1.0, -1.0, -1.0, -1.0))
    VBLENDPS(ymm_wr_0123CDEF, ymm_wr_0123xDEF, ymm_wr_xxxxCxxx, 0b00010000)

    ymm_fr_9FDB9BDF, ymm_fi_8ECA8ACE = YMMRegister(), YMMRegister()
    VPERMPS(ymm_fr_9FDB9BDF, ymm_shuffle_17531357, ymm_wr[1])
    VPERMPS(ymm_fi_8ECA8ACE, ymm_shuffle_06420246, ymm_wi[1])
    VSUBPS(ymm_wi_xxxxCxxx, ymm_fr_9FDB9BDF, ymm_fi_8ECA8ACE)
    VFMADD231PS(ymm_fr_9FDB9BDF, ymm_fi_8ECA8ACE, Constant.float32x8(0.0, -1.0, -1.0, -1.0, +1.0, +1.0, +1.0, +1.0))
    VBLENDPS(ymm_wi_0123CDEF, ymm_wi_0123xDEF, ymm_wi_xxxxCxxx, 0b00010000)

    # Lane 0 of the 89AB4567 halves carries f0.im / f1.im, taken from Step 1.A.
    ymm_wr_x9AB4567, ymm_wi_x9AB4567 = ymm_fr_8ECA8ACE, ymm_fr_9FDB9BDF
    ymm_wr_89AB4567, ymm_wi_89AB4567 = YMMRegister(), YMMRegister()
    VBLENDPS(ymm_wr_89AB4567, ymm_wr_x9AB4567, ymm_fi_02461642, 0b00000001)
    VBLENDPS(ymm_wi_89AB4567, ymm_wi_x9AB4567, ymm_fi_13570753, 0b00000001)

    # Merge the two partially-ordered halves into contiguous 0-7 and 8-15 registers.
    ymm_wr_01234567, ymm_wr_89ABCDEF = YMMRegister(), YMMRegister()
    VBLENDPS(ymm_wr_01234567, ymm_wr_0123CDEF, ymm_wr_89AB4567, 0xF0)
    VBLENDPS(ymm_wr_89ABCDEF, ymm_wr_0123CDEF, ymm_wr_89AB4567, 0x0F)

    ymm_wi_01234567, ymm_wi_89ABCDEF = YMMRegister(), YMMRegister()
    VBLENDPS(ymm_wi_01234567, ymm_wi_0123CDEF, ymm_wi_89AB4567, 0xF0)
    VBLENDPS(ymm_wi_89ABCDEF, ymm_wi_0123CDEF, ymm_wi_89AB4567, 0x0F)

    SWAP.REGISTERS(ymm_wr[0], ymm_wr_01234567)
    SWAP.REGISTERS(ymm_wi[0], ymm_wi_01234567)
    SWAP.REGISTERS(ymm_wr[1], ymm_wr_89ABCDEF)
    SWAP.REGISTERS(ymm_wi[1], ymm_wi_89ABCDEF)

    if bit_reversal:
        # Bit reversal
        # w[0] = x0 x8 x4 x12 x2 x10 x6 x14
        # w[1] = x1 x9 x5 x13 x3 x11 x7 x15
        ymm_bit_reversal_mask = YMMRegister()
        VMOVDQA(ymm_bit_reversal_mask, Constant.uint32x8(0, 2, 4, 6, 1, 3, 5, 7))
        for ymm in interleave(ymm_wr, ymm_wi):
            VPERMPS(ymm, ymm_bit_reversal_mask, ymm)
예제 #13
0
def train_mix(args, labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch,
              use_cuda, target_disb, emp_distb_u, pseudo_orig, pseudo_refine):
    """Run one epoch of MixMatch-style semi-supervised training with optional DARP.

    For each of ``args.val_iteration`` steps: draws a labeled and an unlabeled
    batch (restarting loaders that run out), produces sharpened pseudo-labels
    from two augmented views of the unlabeled batch, optionally refines them
    with DARP (iterative normalization + ``opt_solver``) after ``args.warm``
    epochs, applies mixup across labeled and unlabeled samples, and optimizes
    ``criterion``'s supervised + weighted unsupervised loss.

    Returns:
        Tuple of (avg loss, avg supervised loss, avg unsupervised loss,
        emp_distb_u, pseudo_orig, pseudo_refine); the last three carry state
        across epochs.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()

    bar = Bar('Training', max=args.val_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    model.train()
    for batch_idx in range(args.val_iteration):
        # Restart exhausted loaders: the iteration count is fixed by
        # args.val_iteration, not by loader length.
        # (Fix: use the next() builtin — iterator.next() is Python 2 only —
        # and catch StopIteration specifically instead of a bare except.)
        try:
            inputs_x, targets_x, _ = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x, targets_x, _ = next(labeled_train_iter)

        try:
            (inputs_u, inputs_u2), _, idx_u = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            (inputs_u, inputs_u2), _, idx_u = next(unlabeled_train_iter)

        # Measure data loading time
        data_time.update(time.time() - end)
        batch_size = inputs_x.size(0)

        # Transform label to one-hot
        targets_x = torch.zeros(batch_size, args.num_class).scatter_(1, targets_x.view(-1, 1), 1)
        if use_cuda:
            inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
            inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()

        # Generate the pseudo labels by aggregation and sharpening
        with torch.no_grad():
            outputs_u, _ = model(inputs_u)
            outputs_u2, _ = model(inputs_u2)
            # Average predictions over the two augmented views.
            p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2

            # Temperature sharpening followed by renormalization.
            pt = p ** (1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)

            # Update the saved predictions with current one
            p = targets_u
            pseudo_orig[idx_u, :] = p.data.cpu()
            pseudo_orig_backup = pseudo_orig.clone()

            # Applying DARP
            if args.darp and epoch > args.warm:
                if batch_idx % args.num_iter == 0:
                    # Iterative normalization
                    targets_u, weights_u = estimate_pseudo(target_disb, pseudo_orig, args.num_class, args.alpha)
                    scale_term = targets_u * weights_u.reshape(1, -1)
                    pseudo_orig = (pseudo_orig * scale_term + 1e-6) \
                                      / (pseudo_orig * scale_term + 1e-6).sum(dim=1, keepdim=True)

                    # Looser solver tolerance for datasets with unknown or
                    # many-class distributions.
                    if args.dataset == 'stl10' or args.dataset == 'cifar100':
                        opt_res = opt_solver(pseudo_orig, target_disb, args.iter_T, 0.3)
                    else:
                        opt_res = opt_solver(pseudo_orig, target_disb, args.iter_T, 0.1)

                    # Updated pseudo-labels are saved
                    pseudo_refine = opt_res

                    # Debug output: refined class totals vs. target distribution.
                    print(pseudo_refine.sum(dim=0))
                    print(target_disb)

                    # Select refined labels for this batch; restore the
                    # un-normalized predictions for future DARP rounds.
                    targets_u = opt_res[idx_u].detach().cuda()
                    pseudo_orig = pseudo_orig_backup
                else:
                    # Using previously saved pseudo-labels
                    targets_u = pseudo_refine[idx_u].cuda()

        # Mixup
        all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)

        l = np.random.beta(args.mix_alpha, args.mix_alpha)
        l = max(l, 1 - l)  # keep the original sample dominant
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabed samples between batches to get correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        logits = [model(mixed_input[0])[0]]
        for input in mixed_input[1:]:
            logits.append(model(input)[0])

        # put interleaved samples back
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        Lx, Lu, w = criterion(args, logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:],
                              epoch + batch_idx / args.val_iteration)
        Lu *= w
        loss = Lx + Lu

        # record loss
        losses.update(loss.item(), inputs_x.size(0))
        losses_x.update(Lx.item(), inputs_x.size(0))
        losses_u.update(Lu.item(), inputs_x.size(0))
        ws.update(w, inputs_x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                      'Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f}'.format(
                    batch=batch_idx + 1,
                    size=args.val_iteration,
                    data=data_time.avg,
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    loss_x=losses_x.avg,
                    loss_u=losses_u.avg,
                    )
        bar.next()
    bar.finish()

    return (losses.avg, losses_x.avg, losses_u.avg, emp_distb_u, pseudo_orig, pseudo_refine)
예제 #14
0
def train_remix(args, labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch,
                use_cuda, target_disb, emp_distb_u, pseudo_orig, pseudo_refine):
    """Run one epoch of ReMixMatch-style semi-supervised training with optional DARP.

    Extends the MixMatch loop with ReMixMatch components: a rotation
    self-supervision loss (Lr) on randomly rotated unlabeled images, an
    entropy-minimization loss (Le) on the strongly augmented view, tracking of
    the empirical unlabeled class distribution (``emp_distb_u``, a rolling
    window of per-batch means), and optional distribution alignment
    (``args.align``). DARP refinement of pseudo-labels is applied after
    ``args.warm`` epochs as in ``train_mix``.

    Returns:
        Tuple of (avg loss, avg supervised loss, avg unsupervised loss,
        emp_distb_u, pseudo_orig, pseudo_refine); the last three carry state
        across epochs.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    losses_r = AverageMeter()
    losses_e = AverageMeter()
    ws = AverageMeter()
    end = time.time()

    bar = Bar('Training', max=args.val_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    model.train()
    for batch_idx in range(args.val_iteration):
        # Restart exhausted loaders: the iteration count is fixed by
        # args.val_iteration, not by loader length.
        # (Fix: use the next() builtin — iterator.next() is Python 2 only —
        # and catch StopIteration specifically instead of a bare except.)
        try:
            inputs_x, targets_x, _ = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x, targets_x, _ = next(labeled_train_iter)

        try:
            (inputs_u, inputs_u2, inputs_u3), _, idx_u = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            (inputs_u, inputs_u2, inputs_u3), _, idx_u = next(unlabeled_train_iter)

        # Measure data loading time
        data_time.update(time.time() - end)
        batch_size = inputs_x.size(0)

        # Transform label to one-hot
        targets_x = torch.zeros(batch_size, args.num_class).scatter_(1, targets_x.view(-1, 1), 1)
        if use_cuda:
            inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
            inputs_u, inputs_u2, inputs_u3 = inputs_u.cuda(), inputs_u2.cuda(), inputs_u3.cuda()

        # Rotate images for the self-supervised rotation-prediction task.
        # NOTE(review): the 32x32 reshape assumes CIFAR/STL-sized 3-channel
        # inputs — confirm against the dataloaders.
        temp = []
        targets_r = torch.randint(0, 4, (inputs_u2.size(0),)).long()
        for i in range(inputs_u2.size(0)):
            inputs_rot = torch.rot90(inputs_u2[i], targets_r[i], [1, 2]).reshape(1, 3, 32, 32)
            temp.append(inputs_rot)
        inputs_r = torch.cat(temp, 0)
        targets_r = torch.zeros(batch_size, 4).scatter_(1, targets_r.view(-1, 1), 1)
        inputs_r, targets_r = inputs_r.cuda(), targets_r.cuda(non_blocking=True)

        # Generate the pseudo labels
        with torch.no_grad():
            outputs_u, _ = model(inputs_u)
            p = torch.softmax(outputs_u, dim=1)

            # Tracking the empirical distribution on the unlabeled samples (ReMixMatch):
            # keep a rolling window of at most 128 per-batch mean predictions.
            real_batch_idx = batch_idx + epoch * args.val_iteration
            if real_batch_idx == 0:
                emp_distb_u = p.mean(0, keepdim=True)
            elif real_batch_idx // 128 == 0:
                emp_distb_u = torch.cat([emp_distb_u, p.mean(0, keepdim=True)], 0)
            else:
                emp_distb_u = emp_distb_u[-127:]
                emp_distb_u = torch.cat([emp_distb_u, p.mean(0, keepdim=True)], 0)

            # Distribution alignment
            if args.align:
                pa = p * (target_disb.cuda() + 1e-6) / (emp_distb_u.mean(0).cuda() + 1e-6)
                p = pa / pa.sum(dim=1, keepdim=True)

            # Temperature scaling (sharpening) followed by renormalization.
            pt = p ** (1 / args.T)
            targets_u = (pt / pt.sum(dim=1, keepdim=True)).detach()

            # Update the saved predictions with current one
            p = targets_u
            pseudo_orig[idx_u, :] = p.data.cpu()
            pseudo_orig_backup = pseudo_orig.clone()

            # Applying DARP
            if args.darp and epoch > args.warm:
                if batch_idx % args.num_iter == 0:
                    # Iterative normalization
                    targets_u, weights_u = estimate_pseudo(target_disb, pseudo_orig, args.num_class, args.alpha)
                    scale_term = targets_u * weights_u.reshape(1, -1)
                    pseudo_orig = (pseudo_orig * scale_term + 1e-6) \
                                      / (pseudo_orig * scale_term + 1e-6).sum(dim=1, keepdim=True)

                    # Looser solver tolerance for datasets with unknown or
                    # many-class distributions.
                    if args.dataset == 'stl10' or args.dataset == 'cifar100':
                        opt_res = opt_solver(pseudo_orig, target_disb, args.iter_T, 0.3)
                    else:
                        opt_res = opt_solver(pseudo_orig, target_disb, args.iter_T, 0.1)

                    # Updated pseudo-labels are saved
                    pseudo_refine = opt_res

                    # Select refined labels for this batch; restore the
                    # un-normalized predictions for future DARP rounds.
                    targets_u = opt_res[idx_u].detach().cuda()
                    pseudo_orig = pseudo_orig_backup
                else:
                    # Using previously saved pseudo-labels
                    targets_u = pseudo_refine[idx_u].cuda()

        # Mixup
        all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2, inputs_u3], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u, targets_u], dim=0)

        l = np.random.beta(args.mix_alpha, args.mix_alpha)
        l = max(l, 1 - l)  # keep the original sample dominant
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabed samples between batches to get correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        logits = [model(mixed_input[0])[0]]
        for input in mixed_input[1:]:
            logits.append(model(input)[0])

        # put interleaved samples back
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        Lx, Lu, w = criterion(args, logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:],
                              epoch + batch_idx / args.val_iteration)
        # Rotation-prediction loss on the rotated unlabeled images.
        _, logits_r = model(inputs_r)
        Lu *= w
        Lr = -1 * torch.mean(torch.sum(F.log_softmax(logits_r, dim=1) * targets_r, dim=1))
        Lr *= args.w_rot

        # Entropy minimization for unlabeled samples (strong augmented)
        outputs_u2, _ = model(inputs_u2)
        Le = -1 * torch.mean(torch.sum(F.log_softmax(outputs_u2, dim=1) * targets_u, dim=1))
        Le *= args.w_ent * linear_rampup(epoch + batch_idx / args.val_iteration, args.epochs)
        loss = Lx + Lu + Lr + Le

        # record loss
        losses.update(loss.item(), inputs_x.size(0))
        losses_x.update(Lx.item(), inputs_x.size(0))
        losses_u.update(Lu.item(), inputs_x.size(0))
        losses_r.update(Lr.item(), inputs_x.size(0))
        losses_e.update(Le.item(), inputs_x.size(0))
        ws.update(w, inputs_x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                      'Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | Loss_r: {loss_r:.4f} | ' \
                      'Loss_e: {loss_e:.4f}'.format(
                    batch=batch_idx + 1,
                    size=args.val_iteration,
                    data=data_time.avg,
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    loss_x=losses_x.avg,
                    loss_u=losses_u.avg,
                    loss_r=losses_r.avg,
                    loss_e=losses_e.avg,
                    )
        bar.next()
    bar.finish()

    return (losses.avg, losses_x.avg, losses_u.avg, emp_distb_u, pseudo_orig, pseudo_refine)
예제 #15
0
def amp_out(time):
    """
    A double-decay diminuendo.

    Built by time-reversing the double-attack crescendo produced by
    ``amp_in``: destinations run backwards, durations are shifted so the
    envelope still starts at t=0, and the curves are reused as-is.
    """
    dests, durations, curves = split(amp_in(time), 3)
    reversed_dests = dests[::-1]
    reversed_durations = [0.] + durations[:0:-1]
    return interleave(reversed_dests, reversed_durations, curves)
예제 #16
0
파일: c8gemm.py 프로젝트: ominux/NNPACK
def cgemm_loop(ymm_c, reg_a, reg_b, reg_k, step_k, loop, conjugate_b, mixed_columns):
	ymm_c_real, ymm_c_imag = ymm_c
	assert isinstance(reg_k, GeneralPurposeRegister64)
	assert isinstance(step_k, int) and step_k >= 1
	assert isinstance(mixed_columns, bool)
	assert isinstance(loop, Loop)

	assert isinstance(ymm_c_real, list) and isinstance(ymm_c_imag, list) and len(ymm_c_real) == len(ymm_c_imag)
	mr = len(ymm_c_real)
	assert isinstance(ymm_c_real[0], list)
	nr = len(ymm_c_real[0])
	assert all(isinstance(ymm_c_real_m, list) and len(ymm_c_real_m) == nr for ymm_c_real_m in ymm_c_real)
	assert all(isinstance(ymm_c_imag_m, list) and len(ymm_c_imag_m) == nr for ymm_c_imag_m in ymm_c_imag)

	step_a, step_b = mr * step_k * YMMRegister.size * 2, nr * step_k * YMMRegister.size * 2
	disp_shift_a = 0 if step_a <= 128 else -128
	disp_shift_b = 0 if step_b <= 128 else -128

	ymm_a_real, ymm_a_imag = tuple([YMMRegister() for m in range(mr)] for c in range(2))
	ymm_b_real, ymm_b_imag = tuple([YMMRegister() for n in range(nr)] for c in range(2))

	use_disp_shift = False

	if step_k > 1:
		if disp_shift_a != 0:
			SUB(reg_a, disp_shift_a)
		if disp_shift_b != 0:
			SUB(reg_b, disp_shift_b)
		SUB(reg_k, step_k)
		JB(loop.end)

	with loop:
		for k in range(step_k):
			for i, ymm_a in enumerate(interleave(ymm_a_real, ymm_a_imag)):
				VMOVAPS(ymm_a, [reg_a + (i + 2*mr*k) * YMMRegister.size + disp_shift_a])

			for i, ymm_b in enumerate(interleave(ymm_b_real, ymm_b_imag)):
				VMOVAPS(ymm_b, [reg_b + (i + 2*nr*k) * YMMRegister.size + disp_shift_b])

			for n in range(nr):
				for m in range(mr):
					VFMADD231PS(ymm_c_real[m][n], ymm_a_real[m], ymm_b_real[n])

				if mixed_columns:
					VBLENDPS(ymm_b_real[n], ymm_b_real[n], ymm_b_imag[n], 0b00000011)
				for m in range(mr):
					VFMADD231PS(ymm_c_imag[m][n], ymm_a_imag[m], ymm_b_real[n])

			if nr > 1 and mixed_columns:
				zero_columns01_mask = YMMRegister()
				VMOVAPS(zero_columns01_mask, Constant.uint32x8(0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF))
			else:
				zero_columns01_mask = Constant.uint32x8(0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)

			# if step_k > 1:
			# 	PREFETCHNTA([reg_a + 640])

			for n in range(nr):
				if mixed_columns:
					VANDPS(ymm_b_imag[n], ymm_b_imag[n], zero_columns01_mask)
				for m in range(mr):
					if conjugate_b:
						VFMADD231PS(ymm_c_real[m][n], ymm_a_imag[m], ymm_b_imag[n])
						VFNMADD231PS(ymm_c_imag[m][n], ymm_a_real[m], ymm_b_imag[n])
					else:
						VFNMADD231PS(ymm_c_real[m][n], ymm_a_imag[m], ymm_b_imag[n])
						VFMADD231PS(ymm_c_imag[m][n], ymm_a_real[m], ymm_b_imag[n])

		SUB(reg_a, -step_a)
		SUB(reg_b, -step_b)
		if step_k > 1:
			SUB(reg_k, step_k)
			JAE(loop.begin)
		else:
			DEC(reg_k)
			JNE(loop.begin)

	if step_k > 1:
		if disp_shift_a:
			ADD(reg_a, disp_shift_a)
		if disp_shift_b:
			ADD(reg_b, disp_shift_b)
		ADD(reg_k, step_k)