Example #1
def mask_to_register(mask):
    # Return a 64-bit register holding `mask`, using the enclosing
    # OrderedDict `maskcache` as an LRU cache of loaded mask registers.
    mask = Mask.as_immediate(mask)
    if mask in maskcache:
        maskcache.move_to_end(mask)  # cache hit: mark as most recently used
        return maskcache[mask]
    try:
        maskreg = MaskRegister(64, mask)
    except AllocationError:
        # no free register: evict the least-recently-used mask and reuse it
        _, maskreg = maskcache.popitem(False)
    x86.mov(maskreg, mask)
    maskcache[mask] = maskreg
    return maskreg
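
The helper above is a plain LRU cache keyed on the mask's immediate value: a hit moves the entry to the back of the OrderedDict, and when MaskRegister allocation fails, the least-recently-used entry is evicted and its register reused for the new mask. A minimal self-contained sketch of the same policy, with plain values standing in for the Mask/MaskRegister types and a hypothetical CAPACITY and free_regs in place of the allocator:

from collections import OrderedDict

CAPACITY = 3           # stand-in for the number of allocatable mask registers
free_regs = iter(['r12', 'r13', 'r14'])  # pretend register allocator
cache = OrderedDict()  # mask value -> register, ordered oldest-first

def lookup(mask):
    """Return a register holding mask, evicting the LRU entry when full."""
    if mask in cache:
        cache.move_to_end(mask)  # hit: mark as most recently used
        return cache[mask]
    if len(cache) < CAPACITY:
        reg = next(free_regs)  # a fresh register is still available
    else:
        _, reg = cache.popitem(last=False)  # full: reuse the LRU register
    # a real code generator would emit a `mov reg, mask` here
    cache[mask] = reg
    return reg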
Example #2
def square_701_shufbytes(out_data, in_data, n):
    r = Register()
    out = [Register() for _ in range(3)]
    moved = [False] * 3

    t1 = Register()
    t2 = Register()
    t3 = Register()
    t4 = Register()
    t5 = Register()

    seq = gen_sequence(n, 701) + 67 * [ZERO]
    seq_regvalues = split_in_size_n(seq, 256)

    for in_data_fragment in in_data:
        x86.vmovdqa(r, in_data_fragment)
        shift_in = shifted = r
        offset = 0
        for delta in range(8):  # 8 possible rotations may be necessary
            rol_meta = None
            if delta > 0:
                # if we've made the previous rotation persistent
                if shift_in is shifted:
                    shifted = t4 if shifted is t3 else t3
                d_nett = delta - offset
                rol_meta = len(x86.INSTRUCTIONS), str(shifted), str(t1)
                x86.macro_v256rol(shifted, shift_in, d_nett, t1, t2)
                rotated = [b for d in range(d_nett) for b in shifted[d::64]]
            # vpshufb cannot move bytes across the two 128-bit (xmm) lanes
            for swap_xmms in [False, True]:
                if swap_xmms:
                    swapped = t5
                    x86.vpermq(swapped, shifted, '01001110')
                else:
                    swapped = shifted
                r_bytes = split_in_size_n(swapped, 8)
                while True:  # may need to extract more than once from the same r
                    bitmask = [[] for _ in range(len(seq_regvalues))]
                    shufmask = [None] * 32
                    for k, seq_value in enumerate(seq_regvalues):
                        s_bytes = split_in_size_n(seq_value, 8)
                        s_xmms = split_in_size_n(s_bytes, 16)
                        r_xmms = split_in_size_n(r_bytes, 16)
                        for i, (s128, r128) in enumerate(zip(s_xmms, r_xmms)):
                            for l, s_byte in enumerate(s128):
                                for m, r_byte in enumerate(r128):
                                    # skip if this output byte is already taken by a different source
                                    if (shufmask[i * 16 + l] is not None
                                            and shufmask[i * 16 + l] != m):
                                        continue
                                    bits = [
                                        ONE if x == y and x != ZERO else ZERO
                                        for x, y in zip(r_byte, s_byte)
                                    ]
                                    if ONE not in bits:
                                        continue
                                    shufmask[i * 16 + l] = m
                                    bitmask[k] += bits
                                    break
                                else:
                                    bitmask[k] += [ZERO] * 8
                                    continue
                                for m, (x, y) in enumerate(zip(bits, s_byte)):
                                    if x == ONE:
                                        seq_regvalues[k][i * 128 + l * 8 +
                                                         m] = None
                                s_bytes = split_in_size_n(seq_regvalues[k], 8)
                    if all(x is None for x in shufmask):
                        break
                    x86.vpshufb(t2, swapped, IndicesMask(shufmask))
                    for k, seq_value in enumerate(seq_regvalues):
                        if ONE not in bitmask[k]:
                            continue
                        if not moved[k]:
                            x86.vpand(out[k], t2, Mask(bitmask[k]))
                            moved[k] = True
                        else:
                            x86.vpand(t1, t2, Mask(bitmask[k]))
                            x86.vpxor(out[k], out[k], t1)
                        # check if we used any of the rotated bits
                        for maskbit, bit in zip(bitmask[k], t2):
                            if delta > 0 and bit in rotated and maskbit is ONE:
                                rol_meta = None

            # TODO this is an ugly hack that should be abstracted
            if rol_meta is not None:
                i, dest, temp = rol_meta
                del x86.INSTRUCTIONS[i]  # delete srlq
                # the remaining shift now sits at index i; retarget it at dest
                x86.INSTRUCTIONS[i] = x86.INSTRUCTIONS[i].replace(temp, dest)
                del x86.INSTRUCTIONS[i + 1]  # delete permq
                del x86.INSTRUCTIONS[i + 1]  # delete xor
            else:
                # if we're keeping the rotation, make it persistent so that the
                # next rotation is smaller (and thus more likely ignorable)
                shift_in = shifted
                offset = delta

    for dst, reg in zip(out_data, out):
        x86.vmovdqa(dst, reg)
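
square_701_shufbytes, like the routines below, leans on helpers that this section does not show. Judging from its call sites (splitting the 768-entry sequence into 256-entry register values, register values into 8-bit bytes, byte lists into 16-byte xmm halves), split_in_size_n is a plain chunking function; a likely sketch:

def split_in_size_n(xs, n):
    """Split the sequence xs into consecutive chunks of length n."""
    return [xs[i:i + n] for i in range(0, len(xs), n)]

gen_sequence(n, 701), by contrast, computes the target bit permutation for n repeated squarings modulo x^701 - 1; its definition is not reproduced here.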
Example #3
def square_701_patience(out_data, in_data, n, callee_saved=0):
    x = list(range(701)) + 3 * [ZERO]
    regs = split_in_size_n(x, 64)

    seq = gen_sequence(n, 701) + 3 * [ZERO]
    seq_r = split_in_size_n(seq, 64)

    moved = [False] * len(seq_r)

    r = Register(64)
    t1 = Register(64)

    for _ in range(callee_saved):
        x86.push_callee_saved(64)

    maskcache = OrderedDict()

    def mask_to_register(mask):
        mask = Mask.as_immediate(mask)
        if mask in maskcache:
            maskcache.move_to_end(mask)
            return maskcache[mask]
        try:
            maskreg = MaskRegister(64, mask)
        except AllocationError:
            _, maskreg = maskcache.popitem(False)
        x86.mov(maskreg, mask)
        maskcache[mask] = maskreg
        return maskreg

    for j, inreg in enumerate(regs):
        x86.mov(r, in_data[j])
        for i, seqreg in enumerate(seq_r):
            piledict = {}
            for rotation in range(64):
                ror_seqreg = seqreg[rotation:] + seqreg[:rotation]
                piles = []
                overlap = [x for x in ror_seqreg if x in inreg and x != ZERO]
                for x in overlap:
                    for pile in piles:
                        try:
                            if pile[-1] <= x:
                                pile.append(x)
                                break
                        except IndexError:  # pile is empty
                            pass
                    else:  # doesn't fit on any existing pile: start a new pile
                        piles.append([x])
                piledict[rotation] = piles
            min_pile_key = min(piledict, key=lambda r: len(piledict[r]))
            if len(piledict[0]) == len(piledict[min_pile_key]):
                min_pile_key = 0  # on a tie, prefer no rotation
            if min_pile_key > 0:
                ror_seqreg = seqreg[min_pile_key:] + seqreg[:min_pile_key]
            else:
                ror_seqreg = seqreg

            for pile in piledict[min_pile_key]:
                emask = [ZERO] * 64
                for bit in pile:
                    emask[inreg.index(bit)] = ONE
                dmask = [ZERO] * 64
                for bit in pile:
                    dmask[ror_seqreg.index(bit)] = ONE

                # For consecutive bits, we do not even need pext/pdep
                if (Mask.consec(dmask) and Mask.consec(emask) and
                    (Mask.degree(emask) < 32 or Mask.degree(dmask) < 32)):
                    delta = (Mask.degree(dmask) - Mask.degree(emask)) % 64
                    x86.mov(t1, r)
                    if Mask.degree(emask) < 32:
                        x86.iand(t1, Mask.as_immediate(emask))
                        x86.rol(t1, delta + min_pile_key)
                        min_pile_key = 0  # to avoid two rols
                    else:
                        x86.rol(t1, delta)
                        x86.iand(t1, Mask.as_immediate(dmask))
                else:
                    # if we can extract using an AND instead
                    if Mask.consec(emask, True) and Mask.degree(emask) < 32:
                        x86.mov(t1, r)
                        x86.iand(t1, Mask.as_immediate(emask))
                    else:
                        x86.pext(t1, r, mask_to_register(emask))
                    x86.pdep(t1, t1, mask_to_register(dmask))

                if min_pile_key > 0:
                    x86.rol(t1, min_pile_key)
                if moved[i]:  # tracked per output register i, the inner loop here
                    x86.xor(out_data[i], t1)
                else:
                    x86.mov(out_data[i], t1)
                    moved[i] = True
    x86.movq(out_data[11], 0)  # to fill up all 768 bits

    for mask in maskcache.values():
        mask.free()

    for _ in range(callee_saved):
        x86.pop_callee_saved(64)
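
The pile construction inside square_701_patience is a patience-sort-style greedy split: each overlapping bit goes onto the first pile whose top is not larger than it, so every pile is a non-decreasing run of source positions. That is exactly what one pext/pdep pair can move, since pext and pdep preserve the relative order of the bits they transfer. A standalone illustration on sample data:

def split_into_piles(overlap):
    """Greedily split overlap into non-decreasing runs (piles)."""
    piles = []
    for x in overlap:
        for pile in piles:
            if pile[-1] <= x:  # extends this pile's run
                pile.append(x)
                break
        else:  # fits on no existing pile: start a new one
            piles.append([x])
    return piles

print(split_into_piles([3, 1, 4, 1, 5, 9, 2, 6]))
# -> [[3, 4, 5, 9], [1, 1, 2, 6]]: two piles, hence two pext/pdep pairs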
Example #4
def square_1_701(out_data, in_data):
    """ Operates on MemoryFragments containing the polynomials """

    r = Register()

    r_out = [Register() for _ in range(3)]
    r_out_b = [Register() for _ in range(3)]
    r_out_c = [Register() for _ in range(3)]
    result = [Register() for _ in range(3)]

    t1 = Register()
    t2 = Register()
    t3 = Register()

    for i in range(3):
        x86.vmovdqa(r, in_data[i])
        for j in range(0, 8):
            if j > 0:
                shifted = t3
                x86.macro_v256rol(shifted, r, j, t1, t2)
            else:
                shifted = r
            if i == 0:
                if j < 4:
                    if j == 0:
                        mask = '00000001' * 32
                    if j == 1:
                        mask = '00000100' * 32
                    if j == 2:
                        mask = '00010000' * 32
                    if j == 3:
                        mask = '01000000' * 32

                    if j == 0:
                        x86.vpand(r_out[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out[i], r_out[i], t1)
                elif 4 <= j < 8:
                    if j == 4:
                        mask = '00000001' * 32
                    if j == 5:
                        mask = '00000100' * 32
                    if j == 6:
                        mask = '00010000' * 32
                    if j == 7:
                        mask = '01000000' * 32

                    if j == 4:
                        x86.vpand(r_out_b[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_b[i], r_out_b[i], t1)
            if i == 1:
                if j < 4:
                    if j == 0:
                        mask = '00100000' * 19 + '00100000' + '00000001' * 12
                    if j == 1:
                        mask = '10000000' * 19 + '10000000' + '00000100' * 12
                    if j == 2:
                        mask = '00000000' * 19 + '00000000' + '00010000' * 12
                    if j == 3:
                        mask = '00000000' * 19 + '00000000' + '01000000' * 12

                    if j == 0:
                        x86.vpand(r_out[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out[i], r_out[i], t1)
                if 2 <= j < 6:
                    if j == 2:
                        mask = '00000010' * 19 + '00000010' + '00000000' * 11 + '00000010'  # noqa: E501
                    if j == 3:
                        mask = '00001000' * 19 + '00001000' + '00000000' * 12
                    if j == 4:
                        mask = '00100000' * 19 + '00100000' + '00000000' * 12
                    if j == 5:
                        mask = '10000000' * 19 + '10000000' + '00000000' * 12

                    if j == 2:
                        x86.vpand(r_out_b[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_b[i], r_out_b[i], t1)
                if 4 <= j < 8:
                    if j == 4:
                        mask = '00000000' * 19 + '00000001' + '00000001' * 11 + '00000000'  # noqa: E501
                    if j == 5:
                        mask = '00000000' * 19 + '00000100' + '00000100' * 11 + '00000000'  # noqa: E501
                    if j == 6:
                        mask = '00000010' * 19 + '00010000' + '00010000' * 11 + '00000010'  # noqa: E501
                    if j == 7:
                        mask = '00001000' * 19 + '00000000' + '01000000' * 11 + '00001000'  # noqa: E501

                    if j == 4:
                        x86.vpand(r_out_c[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_c[i], r_out_c[i], t1)

            if i == 2:
                if j < 2:
                    if j == 0:
                        mask = '00100000' * 32
                    if j == 1:
                        mask = '10000000' * 32

                    if j == 0:
                        x86.vpand(r_out[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out[i], r_out[i], t1)
                if 2 <= j < 6:
                    if j == 2:
                        mask = '00000010' * 32
                    if j == 3:
                        mask = '00001000' * 32
                    if j == 4:
                        mask = '00100000' * 32
                    if j == 5:
                        mask = '10000000' * 32

                    if j == 2:
                        x86.vpand(r_out_b[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_b[i], r_out_b[i], t1)
                if 6 <= j < 8:
                    if j == 6:
                        mask = '00000010' * 32
                    if j == 7:
                        mask = '00001000' * 32

                    if j == 6:
                        x86.vpand(r_out_c[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_c[i], r_out_c[i], t1)

    # move bit 511 to r_out_b[2]
    x86.vpand(t1, r_out_b[1], Mask('0001'))
    x86.vpxor(r_out_b[2], r_out_b[2], t1)

    highbyte_mask = Mask(2 * ('1' + '0' * 15))
    highbyte_nmask = Mask(2 * ('0' + '1' * 15))

    mask = IndicesMask(2 * sum(zip(range(8), [None] * 8), ()))
    x86.vpshufb(result[0], r_out[0], mask)

    mask = IndicesMask(2 * sum(zip([None] * 8, range(1, 9)), ()))
    x86.vpshufb(t1, r_out_b[0], mask)
    x86.vpxor(result[0], t1, result[0])

    x86.vextracti128(result[1], result[0], 1)
    mask = IndicesMask(2 * sum(zip(range(8, 16), [None] * 8), ()))
    x86.vpshufb(result[2], r_out[0], mask)

    mask = IndicesMask(2 * (sum(zip([None] * 8, range(9, 16)),
                                ()) + (None, 0)))
    x86.vpshufb(t1, r_out_b[0], mask)
    x86.vpand(t2, t1, highbyte_mask)
    x86.vpand(t1, t1, highbyte_nmask)
    x86.vpermq(t2, t2, '01001100')
    x86.vpxor(t1, t1, t2)
    x86.vpxor(result[2], t1, result[2])

    x86.vinserti128(result[0], result[0], result[2], 1)
    x86.vextracti128(t1, result[2], 1)
    x86.vinserti128(result[1], result[1], t1, 1)

    # ---

    mask = IndicesMask(
        sum(zip(range(8), [None] * 8), ()) + sum(zip(
            [None] * 8,
            range(8),
        ), ()))
    x86.vpshufb(result[2], r_out[1], mask)

    mask = IndicesMask(2 * sum(zip([None] * 8, range(1, 9)), ()))
    x86.vpshufb(t1, r_out_c[1], mask)
    x86.vpxor(result[2], t1, result[2])

    mask = IndicesMask((None, ) * 16 + sum(zip(range(8), [None] * 8), ()))
    x86.vpshufb(t1, r_out_b[1], mask)
    x86.vpxor(result[2], t1, result[2])

    x86.vextracti128(t2, result[2], 1)
    x86.vpermq(t2, t2, '11010011')
    x86.vpxor(result[0], t2, result[0])

    mask = IndicesMask(
        sum(zip(range(8, 12), [None] * 8), ()) +
        sum(zip([None] * 8, range(12, 16)), ()) +
        sum(zip(
            [None] * 8,
            range(8, 16),
        ), ()))
    x86.vpshufb(t1, r_out[1], mask)

    mask = IndicesMask(2 * (sum(zip([None] * 7, range(9, 16)),
                                ()) + (None, 0)))
    x86.vpshufb(t2, r_out_c[1], mask)
    x86.vpand(t3, t2, highbyte_mask)
    x86.vpand(t2, t2, highbyte_nmask)
    x86.vpermq(t3, t3, '01001100')
    x86.vpxor(t2, t3, t2)
    x86.vpxor(t1, t1, t2)

    mask = IndicesMask(2 * sum(zip(range(8, 16), [None] * 8), ()))
    x86.vpshufb(t2, r_out_b[1], mask)
    x86.vpxor(t2, t2, t1)
    x86.vpand(t3, t2, Mask('0001'))
    x86.vinserti128(result[2], result[2], t3, 1)
    # completes the first half (coefficients 0 through 350)

    x86.vpand(t1, t2, Mask('0110'))
    x86.vpermq(t1, t1, '10000001')
    x86.vpxor(result[0], result[0], t1)

    x86.vpand(t1, t2, Mask('1000'))
    x86.vpermq(t1, t1, '00000011')
    x86.vpxor(result[1], t1, result[1])

    mask = IndicesMask(2 * sum(zip(range(8), [None] * 8), ()))
    x86.vpshufb(t1, r_out_b[2], mask)

    mask = IndicesMask(2 * sum(zip([None] * 8, range(1, 9)), ()))
    x86.vpshufb(t2, r_out_c[2], mask)
    x86.vpxor(t1, t1, t2)

    mask = IndicesMask(2 * sum(zip([None] * 8, range(8)), ()))
    x86.vpshufb(t2, r_out[2], mask)
    x86.vpxor(t1, t1, t2)

    x86.vinserti128(t3, t3, t1, 0)
    x86.vpermq(t3, t3, '11010011')
    x86.vpxor(result[1], t3, result[1])

    x86.vextracti128(t3, t1, 1)
    x86.vpermq(t3, t3, '11010011')
    x86.vpxor(result[2], t3, result[2])

    mask = IndicesMask(2 * sum(zip(range(8, 16), [None] * 8), ()))
    x86.vpshufb(t1, r_out_b[2], mask)

    mask = IndicesMask(2 * sum(zip([None] * 8, range(8, 16)), ()))
    x86.vpshufb(t2, r_out[2], mask)
    x86.vpxor(t1, t1, t2)

    mask = IndicesMask(2 * (sum(zip([None] * 7, range(9, 16)),
                                ()) + (None, 0)))
    x86.vpshufb(t2, r_out_c[2], mask)
    x86.vpand(t3, t2, highbyte_mask)
    x86.vpand(t2, t2, highbyte_nmask)
    x86.vpermq(t3, t3, '01001100')
    x86.vpxor(t2, t3, t2)
    x86.vpxor(t1, t2, t1)

    x86.vpermq(t2, t1, '00111111')
    x86.vpxor(result[1], result[1], t2)

    x86.vpermq(t3, t1, '11111101')
    x86.vpxor(result[2], result[2], t3)

    for i in range(3):
        x86.vmovdqa(out_data[i], result[i])
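
For context on why square_1_701 needs only rotations, ANDs and XORs with fixed masks, and byte shuffles: squaring is linear over GF(2), so (a_0 + a_1 x + a_2 x^2 + ...)^2 = a_0 + a_1 x^2 + a_2 x^4 + ..., and a single squaring merely spreads the coefficient bits apart, sending bit i to position 2i before the reduction modulo x^701 - 1 rotates the overflow back in. The per-rotation masks such as Mask('00000001' * 32) each collect, from one rotated copy, the bits belonging to one fixed position within every byte. A one-integer sketch of the spreading step:

def gf2x_square(f):
    """Square a GF(2)[x] polynomial stored as an integer bitmask:
    coefficient i of f becomes coefficient 2*i of f*f (no carries in GF(2))."""
    out, i = 0, 0
    while f >> i:
        if (f >> i) & 1:
            out |= 1 << (2 * i)
        i += 1
    return out

assert gf2x_square(0b1011) == 0b1000101  # (x^3 + x + 1)^2 = x^6 + x^2 + 1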
Example #5
def square_350_701(dst, src):
    """ Requires source and destination registers to be disjunct. """
    r = src
    r_out = dst

    maskreg = MaskRegister()
    lowbitmask = Mask('0' * 255 + '1')
    x86.vmovdqa(maskreg, lowbitmask)

    lowbitreg = Register()
    x86.vpand(lowbitreg, maskreg, r[0])
    x86.vpandn(r[0], maskreg, r[0])

    rest = Register()
    twobits = Register()
    nexttwobits = Register()
    mask0001 = Mask('0001')
    x86.vmovdqa(maskreg, mask0001)

    for i in range(2, -1, -1):
        x86.vpsllq(rest, r[i], 2)
        x86.vpsrlq(twobits, r[i], 62)
        x86.vpermq(twobits, twobits, '10010011')
        x86.vpand(nexttwobits, maskreg, twobits)
        x86.vpandn(twobits, maskreg, twobits)
        x86.vpxor(r[i], rest, twobits)
        if i + 1 < 3:
            x86.vpxor(r[i + 1], r[i + 1], nexttwobits)

    mask_bit_in_byte = [
        Mask(32 * ([ZERO] * i + [ONE] + [ZERO] * (7 - i))) for i in range(8)
    ]
    bits = Register()
    accum = Register()

    for i in range(2, -1, -1):
        for j in range(8):
            x86.vpand(bits, r[i], mask_bit_in_byte[j])
            if j == 0:
                x86.vpshlq(accum, bits, 7 - 2 * j)
            else:
                x86.vpshlq(bits, bits, 7 - 2 * j)
                if j == 7:
                    x86.vpxor(r[i], accum, bits)
                else:
                    x86.vpxor(accum, accum, bits)

    x86.vpermq(lowbitreg, lowbitreg, '11001111')
    x86.vpshlq(lowbitreg, lowbitreg, 56)
    x86.vpxor(r[2], lowbitreg, r[2])

    indices = IndicesMask(
        list(range(15, -1, -1)) + [None] * 8 + list(range(7, -1, -1)))
    x86.vpshufb(r_out[2], r[0], indices)
    x86.vpermq(r_out[2], r_out[2], '10010011')

    t1 = Register()

    for i in range(2):
        indices = IndicesMask([None] * 24 + list(range(15, 7, -1)))
        x86.vpshufb(r_out[1 - i], r[i], indices)
        indices = IndicesMask(
            list(range(15, -1, -1)) + list(range(7, -1, -1)) + [None] * 8)
        x86.vpshufb(t1, r[i + 1], indices)
        x86.vpxor(r_out[1 - i], t1, r_out[1 - i])
        x86.vpermq(r_out[1 - i], r_out[1 - i], '11010010')
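
Finally, why a dedicated square_350_701 exists and why it is a pure permutation with no mask tables: n squarings modulo x^701 - 1 send coefficient i to i * 2^n mod 701. Since 701 ≡ 5 (mod 8), 2 is a quadratic non-residue modulo 701, so 2^350 = 2^((701-1)/2) ≡ -1 (mod 701). Squaring 350 times is therefore the index negation i -> -i mod 701, a coefficient reversal, which the byte-reversing vpshufb index lists and the per-byte bit reversal above (the vpshlq loop, moving bit j of each byte to bit 7 - j) implement directly. A quick sanity check:

assert pow(2, 350, 701) == 700           # 2^350 ≡ -1 (mod 701)
perm = [(-i) % 701 for i in range(701)]  # where coefficient i ends up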