def square_1_701(out_data, in_data):
    """ Operates on MemoryFragments containing the polynomials """
    r = Register()
    r_out = [Register() for _ in range(3)]
    r_out_b = [Register() for _ in range(3)]
    r_out_c = [Register() for _ in range(3)]
    result = [Register() for _ in range(3)]
    t1 = Register()
    t2 = Register()
    t3 = Register()
    # for every limb: rotate by each of the 8 possible bit offsets, mask out
    # the bits that belong at that offset, and accumulate them per output
    for i in range(3):
        x86.vmovdqa(r, in_data[i])
        for j in range(0, 8):
            if j > 0:
                shifted = t3
                x86.macro_v256rol(shifted, r, j, t1, t2)
            else:
                shifted = r
            if i == 0:
                if j < 4:
                    if j == 0:
                        mask = '00000001' * 32
                    if j == 1:
                        mask = '00000100' * 32
                    if j == 2:
                        mask = '00010000' * 32
                    if j == 3:
                        mask = '01000000' * 32
                    if j == 0:
                        x86.vpand(r_out[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out[i], r_out[i], t1)
                elif 4 <= j < 8:
                    if j == 4:
                        mask = '00000001' * 32
                    if j == 5:
                        mask = '00000100' * 32
                    if j == 6:
                        mask = '00010000' * 32
                    if j == 7:
                        mask = '01000000' * 32
                    if j == 4:
                        x86.vpand(r_out_b[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_b[i], r_out_b[i], t1)
            if i == 1:
                if j < 4:
                    if j == 0:
                        mask = '00100000' * 19 + '00100000' + '00000001' * 12
                    if j == 1:
                        mask = '10000000' * 19 + '10000000' + '00000100' * 12
                    if j == 2:
                        mask = '00000000' * 19 + '00000000' + '00010000' * 12
                    if j == 3:
                        mask = '00000000' * 19 + '00000000' + '01000000' * 12
                    if j == 0:
                        x86.vpand(r_out[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out[i], r_out[i], t1)
                if 2 <= j < 6:
                    if j == 2:
                        mask = '00000010' * 19 + '00000010' + '00000000' * 11 + '00000010'  # noqa: E501
                    if j == 3:
                        mask = '00001000' * 19 + '00001000' + '00000000' * 12
                    if j == 4:
                        mask = '00100000' * 19 + '00100000' + '00000000' * 12
                    if j == 5:
                        mask = '10000000' * 19 + '10000000' + '00000000' * 12
                    if j == 2:
                        x86.vpand(r_out_b[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_b[i], r_out_b[i], t1)
                if 4 <= j < 8:
                    if j == 4:
                        mask = '00000000' * 19 + '00000001' + '00000001' * 11 + '00000000'  # noqa: E501
                    if j == 5:
                        mask = '00000000' * 19 + '00000100' + '00000100' * 11 + '00000000'  # noqa: E501
                    if j == 6:
                        mask = '00000010' * 19 + '00010000' + '00010000' * 11 + '00000010'  # noqa: E501
                    if j == 7:
                        mask = '00001000' * 19 + '00000000' + '01000000' * 11 + '00001000'  # noqa: E501
                    if j == 4:
                        x86.vpand(r_out_c[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_c[i], r_out_c[i], t1)
            if i == 2:
                if j < 2:
                    if j == 0:
                        mask = '00100000' * 32
                    if j == 1:
                        mask = '10000000' * 32
                    if j == 0:
                        x86.vpand(r_out[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out[i], r_out[i], t1)
                if 2 <= j < 6:
                    if j == 2:
                        mask = '00000010' * 32
                    if j == 3:
                        mask = '00001000' * 32
                    if j == 4:
                        mask = '00100000' * 32
                    if j == 5:
                        mask = '10000000' * 32
                    if j == 2:
                        x86.vpand(r_out_b[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_b[i], r_out_b[i], t1)
                if 6 <= j < 8:
                    if j == 6:
                        mask = '00000010' * 32
                    if j == 7:
                        mask = '00001000' * 32
                    if j == 6:
                        x86.vpand(r_out_c[i], shifted, Mask(mask))
                    else:
                        x86.vpand(t1, shifted, Mask(mask))
                        x86.vpxor(r_out_c[i], r_out_c[i], t1)
    # move bit 511 to r_out_b[2]
    x86.vpand(t1, r_out_b[1], Mask('0001'))
    x86.vpxor(r_out_b[2], r_out_b[2], t1)
    # spread the collected bytes to their final positions with vpshufb and
    # recombine everything into the three result limbs
    highbyte_mask = Mask(2 * ('1' + '0' * 15))
    highbyte_nmask = Mask(2 * ('0' + '1' * 15))
    mask = IndicesMask(2 * sum(zip(range(8), [None] * 8), ()))
    x86.vpshufb(result[0], r_out[0], mask)
    mask = IndicesMask(2 * sum(zip([None] * 8, range(1, 9)), ()))
    x86.vpshufb(t1, r_out_b[0], mask)
    x86.vpxor(result[0], t1, result[0])
    x86.vextracti128(result[1], result[0], 1)
    mask = IndicesMask(2 * sum(zip(range(8, 16), [None] * 8), ()))
    x86.vpshufb(result[2], r_out[0], mask)
    mask = IndicesMask(2 * (sum(zip([None] * 8, range(9, 16)), ()) +
                            (None, 0)))
    x86.vpshufb(t1, r_out_b[0], mask)
    x86.vpand(t2, t1, highbyte_mask)
    x86.vpand(t1, t1, highbyte_nmask)
    x86.vpermq(t2, t2, '01001100')
    x86.vpxor(t1, t1, t2)
    x86.vpxor(result[2], t1, result[2])
    x86.vinserti128(result[0], result[0], result[2], 1)
    x86.vextracti128(t1, result[2], 1)
    x86.vinserti128(result[1], result[1], t1, 1)
    # ---
    mask = IndicesMask(
        sum(zip(range(8), [None] * 8), ()) +
        sum(zip(
            [None] * 8,
            range(8),
        ), ()))
    x86.vpshufb(result[2], r_out[1], mask)
    mask = IndicesMask(2 * sum(zip([None] * 8, range(1, 9)), ()))
    x86.vpshufb(t1, r_out_c[1], mask)
    x86.vpxor(result[2], t1, result[2])
    mask = IndicesMask((None, ) * 16 + sum(zip(range(8), [None] * 8), ()))
    x86.vpshufb(t1, r_out_b[1], mask)
    x86.vpxor(result[2], t1, result[2])
    x86.vextracti128(t2, result[2], 1)
    x86.vpermq(t2, t2, '11010011')
    x86.vpxor(result[0], t2, result[0])
    mask = IndicesMask(
        sum(zip(range(8, 12), [None] * 8), ()) +
        sum(zip([None] * 8, range(12, 16)), ()) +
        sum(zip(
            [None] * 8,
            range(8, 16),
        ), ()))
    x86.vpshufb(t1, r_out[1], mask)
    mask = IndicesMask(2 * (sum(zip([None] * 7, range(9, 16)), ()) +
                            (None, 0)))
    x86.vpshufb(t2, r_out_c[1], mask)
    x86.vpand(t3, t2, highbyte_mask)
    x86.vpand(t2, t2, highbyte_nmask)
    x86.vpermq(t3, t3, '01001100')
    x86.vpxor(t2, t3, t2)
    x86.vpxor(t1, t1, t2)
    mask = IndicesMask(2 * sum(zip(range(8, 16), [None] * 8), ()))
    x86.vpshufb(t2, r_out_b[1], mask)
    x86.vpxor(t2, t2, t1)
    x86.vpand(t3, t2, Mask('0001'))
    x86.vinserti128(result[2], result[2], t3, 1)  # complete first 0-350
    x86.vpand(t1, t2, Mask('0110'))
    x86.vpermq(t1, t1, '10000001')
    x86.vpxor(result[0], result[0], t1)
    x86.vpand(t1, t2, Mask('1000'))
    x86.vpermq(t1, t1, '00000011')
    x86.vpxor(result[1], t1, result[1])
    mask = IndicesMask(2 * sum(zip(range(8), [None] * 8), ()))
    x86.vpshufb(t1, r_out_b[2], mask)
    mask = IndicesMask(2 * sum(zip([None] * 8, range(1, 9)), ()))
    x86.vpshufb(t2, r_out_c[2], mask)
    x86.vpxor(t1, t1, t2)
    mask = IndicesMask(2 * sum(zip([None] * 8, range(8)), ()))
    x86.vpshufb(t2, r_out[2], mask)
    x86.vpxor(t1, t1, t2)
    x86.vinserti128(t3, t3, t1, 0)
    x86.vpermq(t3, t3, '11010011')
    x86.vpxor(result[1], t3, result[1])
    x86.vextracti128(t3, t1, 1)
    x86.vpermq(t3, t3, '11010011')
    x86.vpxor(result[2], t3, result[2])
    mask = IndicesMask(2 * sum(zip(range(8, 16), [None] * 8), ()))
    x86.vpshufb(t1, r_out_b[2], mask)
    mask = IndicesMask(2 * sum(zip([None] * 8, range(8, 16)), ()))
    x86.vpshufb(t2, r_out[2], mask)
    x86.vpxor(t1, t1, t2)
    mask = IndicesMask(2 * (sum(zip([None] * 7, range(9, 16)), ()) +
                            (None, 0)))
    x86.vpshufb(t2, r_out_c[2], mask)
    x86.vpand(t3, t2, highbyte_mask)
    x86.vpand(t2, t2, highbyte_nmask)
    x86.vpermq(t3, t3, '01001100')
    x86.vpxor(t2, t3, t2)
    x86.vpxor(t1, t2, t1)
    x86.vpermq(t2, t1, '00111111')
    x86.vpxor(result[1], result[1], t2)
    x86.vpermq(t3, t1, '11111101')
    x86.vpxor(result[2], result[2], t3)
    for i in range(3):
        x86.vmovdqa(out_data[i], result[i])
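
# A minimal pure-Python model of the single squaring above, handy when
# checking the generated code by hand. This is a sketch under the assumption
# (suggested by the function name and the mod-701 index arithmetic elsewhere
# in this file) that we square in GF(2)[x]/(x^701 - 1), where coefficient i
# moves to position 2*i mod 701. The helper name is hypothetical and not part
# of the generator itself.
def ref_square_1_701(bits):
    """ bits: list of 701 values in {0, 1}; returns the squared polynomial """
    out = [0] * 701
    for i, b in enumerate(bits):
        out[(2 * i) % 701] ^= b  # (sum a_i x^i)^2 = sum a_i x^(2i) in char 2
    return out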
def square_701_shufbytes(out_data, in_data, n):
    """ Realizes the bit permutation described by gen_sequence(n, 701),
    using vpshufb byte shuffles of (rotated copies of) the input. """
    r = Register()
    out = [Register() for _ in range(3)]
    moved = [False] * 3
    t1 = Register()
    t2 = Register()
    t3 = Register()
    t4 = Register()
    t5 = Register()
    seq = gen_sequence(n, 701) + 67 * [ZERO]  # pad to 768 = 3 * 256 bits
    seq_regvalues = split_in_size_n(seq, 256)
    for in_data_fragment in in_data:
        x86.vmovdqa(r, in_data_fragment)
        shift_in = shifted = r
        offset = 0
        for delta in range(8):  # 8 possible rotations may be necessary
            rol_meta = None
            if delta > 0:
                # if we've made the previous rotation persistent
                if shift_in is shifted:
                    shifted = t4 if shifted is t3 else t3
                d_nett = delta - offset
                # record where the rol starts so it can be undone later if
                # none of its bits end up being used
                rol_meta = len(x86.INSTRUCTIONS), str(shifted), str(t1)
                x86.macro_v256rol(shifted, shift_in, d_nett, t1, t2)
                rotated = [b for d in range(d_nett) for b in shifted[d::64]]
            # vpshufb cannot cross over xmm lanes
            for swap_xmms in [False, True]:
                if swap_xmms:
                    swapped = t5
                    x86.vpermq(swapped, shifted, '01001110')
                else:
                    swapped = shifted
                r_bytes = split_in_size_n(swapped, 8)
                while True:  # could be necessary to extract twice from same r
                    # greedily pick, for every destination byte, a source
                    # byte that supplies at least one still-needed bit
                    bitmask = [[] for _ in range(len(seq_regvalues))]
                    shufmask = [None] * 32
                    for k, seq_value in enumerate(seq_regvalues):
                        s_bytes = split_in_size_n(seq_value, 8)
                        s_xmms = split_in_size_n(s_bytes, 16)
                        r_xmms = split_in_size_n(r_bytes, 16)
                        for i, (s128, r128) in enumerate(zip(s_xmms, r_xmms)):
                            for l, s_byte in enumerate(s128):
                                for m, r_byte in enumerate(r128):
                                    # if this byte is already taken;
                                    if (shufmask[i * 16 + l] is not None and
                                            shufmask[i * 16 + l] != m):
                                        continue
                                    bits = [
                                        ONE if x == y and x != ZERO else ZERO
                                        for x, y in zip(r_byte, s_byte)
                                    ]
                                    if ONE not in bits:
                                        continue
                                    shufmask[i * 16 + l] = m
                                    bitmask[k] += bits
                                    break
                                else:
                                    bitmask[k] += [ZERO] * 8
                                    continue
                                # mark the extracted bits as done
                                for m, (x, y) in enumerate(zip(bits, s_byte)):
                                    if x == ONE:
                                        seq_regvalues[k][i * 128 + l * 8 + m] = None  # noqa: E501
                                s_bytes = split_in_size_n(seq_regvalues[k], 8)
                    if all(x is None for x in shufmask):
                        break
                    x86.vpshufb(t2, swapped, IndicesMask(shufmask))
                    for k, seq_value in enumerate(seq_regvalues):
                        if ONE not in bitmask[k]:
                            continue
                        if not moved[k]:
                            x86.vpand(out[k], t2, Mask(bitmask[k]))
                            moved[k] = True
                        else:
                            x86.vpand(t1, t2, Mask(bitmask[k]))
                            x86.vpxor(out[k], out[k], t1)
                        # check if we used any of the rotated bits
                        for maskbit, bit in zip(bitmask[k], t2):
                            if delta > 0 and bit in rotated and maskbit is ONE:
                                rol_meta = None
            # TODO this is an ugly hack that should be abstracted
            if rol_meta is not None:
                i, dest, temp = rol_meta
                del x86.INSTRUCTIONS[i]  # delete srlq
                x86.INSTRUCTIONS[i] = x86.INSTRUCTIONS[i].replace(temp, dest)
                del x86.INSTRUCTIONS[i + 1]  # delete permq
                del x86.INSTRUCTIONS[i + 1]  # delete xor
            else:
                # if we're keeping the rotation, make it persistent so that
                # the next rotation is smaller (and thus more likely
                # ignorable)
                shift_in = shifted
                offset = delta
    for m, r in zip(out_data, out):
        x86.vmovdqa(m, r)
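
# n squarings compose to the single bit permutation i -> i * 2**n mod 701,
# which is presumably what gen_sequence(n, 701) tabulates (its exact
# convention, forward or inverse, is not visible in this section). A
# hypothetical pure-Python reference for the composed map:
def ref_square_n_701(bits, n):
    """ bits: list of 701 values in {0, 1}; result after n squarings """
    out = [0] * 701
    step = pow(2, n, 701)
    for i, b in enumerate(bits):
        out[(i * step) % 701] ^= b
    return out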
def square_350_701(dst, src):
    """ Requires source and destination registers to be disjoint. """
    r = src
    r_out = dst
    maskreg = MaskRegister()
    # set bit 0 aside; it is reinserted at its final position further down
    lowbitmask = Mask('0' * 255 + '1')
    x86.vmovdqa(maskreg, lowbitmask)
    lowbitreg = Register()
    x86.vpand(lowbitreg, maskreg, r[0])
    x86.vpandn(r[0], maskreg, r[0])
    rest = Register()
    twobits = Register()
    nexttwobits = Register()
    mask0001 = Mask('0001')
    x86.vmovdqa(maskreg, mask0001)
    # shift the remaining bits left by two, carrying across the 64-bit words
    # and across the three limbs
    for i in range(2, -1, -1):
        x86.vpsllq(rest, r[i], 2)
        x86.vpsrlq(twobits, r[i], 62)
        x86.vpermq(twobits, twobits, '10010011')
        x86.vpand(nexttwobits, maskreg, twobits)
        x86.vpandn(twobits, maskreg, twobits)
        x86.vpxor(r[i], rest, twobits)
        if i + 1 < 3:
            x86.vpxor(r[i + 1], r[i + 1], nexttwobits)
    # reverse the bit order within every byte: bit j moves to bit 7 - j
    mask_bit_in_byte = [
        Mask(32 * ([ZERO] * i + [ONE] + [ZERO] * (7 - i))) for i in range(8)
    ]
    bits = Register()
    accum = Register()
    for i in range(2, -1, -1):
        for j in range(8):
            x86.vpand(bits, r[i], mask_bit_in_byte[j])
            if j == 0:
                x86.vpshlq(accum, bits, 7 - 2 * j)
            else:
                x86.vpshlq(bits, bits, 7 - 2 * j)
                if j == 7:
                    x86.vpxor(r[i], accum, bits)
                else:
                    x86.vpxor(accum, accum, bits)
    # reinsert the bit that was set aside at the start
    x86.vpermq(lowbitreg, lowbitreg, '11001111')
    x86.vpshlq(lowbitreg, lowbitreg, 56)
    x86.vpxor(r[2], lowbitreg, r[2])
    # reverse the byte order across the three limbs
    indices = IndicesMask(
        list(range(15, -1, -1)) + [None] * 8 + list(range(7, -1, -1)))
    x86.vpshufb(r_out[2], r[0], indices)
    x86.vpermq(r_out[2], r_out[2], '10010011')
    t1 = Register()
    for i in range(2):
        indices = IndicesMask([None] * 24 + list(range(15, 7, -1)))
        x86.vpshufb(r_out[1 - i], r[i], indices)
        indices = IndicesMask(
            list(range(15, -1, -1)) + list(range(7, -1, -1)) + [None] * 8)
        x86.vpshufb(t1, r[i + 1], indices)
        x86.vpxor(r_out[1 - i], t1, r_out[1 - i])
        x86.vpermq(r_out[1 - i], r_out[1 - i], '11010010')
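
# Why 350 squarings get a dedicated routine: 701 is prime and congruent to
# 5 mod 8, so 2 is a quadratic non-residue mod 701 and, by Euler's criterion,
# 2**350 = 2**((701 - 1) / 2) is congruent to -1 mod 701. The permutation
# i -> i * 2**350 mod 701 is then simply i -> -i mod 701, i.e. a reversal of
# the coefficient order, which is what the bit-reversal-within-bytes and
# byte-reversal shuffles above appear to implement. A sanity check:
assert pow(2, 350, 701) == 700  # 2**350 is congruent to -1 mod 701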