def eval(cls, bv):
    """Return ``bv`` with its bit order reversed.

    Widths 1, 2 and 3 are built directly. Larger bit-vectors are first
    zero-extended (one bit at a time) until their width is a power of two.
    They are then reversed with the mask-and-swap technique from Hacker's
    Delight. The final extraction keeps only the top ``original width`` bits
    of the reversed, padded value.
    """
    # Source: Hacker's Delight
    if bv.width == 1:
        return bv
    if bv.width == 2:
        return operation.RotateLeft(bv, 1)
    if bv.width == 3:
        return operation.Concat(operation.Concat(bv[0], bv[1]), bv[2])

    original_width = bv.width
    padded = bv
    while padded.width & (padded.width - 1):
        padded = operation.ZeroExtend(padded, 1)
    size = padded.width

    # masks[i] repeats the pattern 0..01..1 made of 2**i zeros and 2**i ones
    masks = [repeat_pattern(pattern01(2 ** i), size)
             for i in range(size.bit_length() - 1)]

    def swap_groups(x, group, mask):
        # exchange each pair of adjacent `group`-bit blocks selected by `mask`
        distance = core.Constant(group, size)
        return ((x & mask) << distance) | ((x >> distance) & mask)

    if size == 32:
        # swap bits, pairs and nibbles; then reverse the four bytes at once
        # with two rotations instead of the last two swap stages
        for i, mask in enumerate(masks[:3]):
            padded = swap_groups(padded, 2 ** i, mask)
        byte_mask = masks[3]
        padded = (operation.RotateRight(padded & byte_mask, 8)
                  | (operation.RotateLeft(padded, 8) & byte_mask))
    else:
        for i, mask in enumerate(masks):
            padded = swap_groups(padded, 2 ** i, mask)

    return padded[:size - original_width]
def pattern01(width):
    """Return the constant ``0...01...1`` of total width ``2 * width``.

    The ``width`` most significant bits are zeros and the ``width`` least
    significant bits are ones.
    """
    low_half_zeros = core.Constant(0, width)
    # the complement of the all-zero constant is the all-ones constant
    return operation.Concat(low_half_zeros, ~low_half_zeros)
def bv2pysmt(bv):
    """Convert a bit-vector type to a pySMT type.

    Python integers are returned unchanged. They are the non-term operands of
    some operations, such as the indices of ``Extract`` or the count of
    ``Repeat``. Variables and constants become pySMT symbols and bit-vector
    constants. Operations are translated recursively.

    ``BvComp`` is translated to a Boolean ``Equals``, not to a 1-bit
    ``BVComp``.

    Some shifts and rotations have a first operand whose width is not a power
    of two. These are emulated instead of being translated directly:

    - Shifts zero-extend both operands to the next power-of-two width, shift
      there, and keep the low bits.
    - Rotations are rebuilt from ``Extract`` and ``Concat``.

    >>> from arxpy.bitvector.core import Constant, Variable
    >>> from arxpy.diffcrypt.smt import bv2pysmt
    >>> bv2pysmt(Constant(0b00000001, 8))
    1_8
    >>> x, y = Variable("x", 8), Variable("y", 8)
    >>> bv2pysmt(x)
    x
    >>> bv2pysmt(x + y)
    (x + y)
    >>> bv2pysmt(x <= y)
    (x u<= y)
    >>> bv2pysmt(x[4: 2])
    x[2:4]

    Raises:
        NotImplementedError: if ``bv`` (or one of its sub-terms) has no known
            pySMT translation.
    """
    msg = "unknown conversion of {} to a pySMT type".format(type(bv).__name__)

    if isinstance(bv, int):
        return bv
    if isinstance(bv, core.Variable):
        return sc.Symbol(bv.name, typing.BVType(bv.width))
    if isinstance(bv, core.Constant):
        return sc.BV(bv.val, bv.width)
    if not isinstance(bv, operation.Operation):
        raise NotImplementedError(msg)

    def is_power_of_2(width):
        return (width & (width - 1)) == 0

    def emulate_shift(build_shift):
        # pySMT shifts are used only on power-of-two widths here: pad both
        # operands with zeros on the MSB side, shift, and keep the low bits
        # (the `offset` padding bits are dropped by the extraction)
        x, r = bv.args
        offset = 0
        while not is_power_of_2(x.width):
            x = operation.ZeroExtend(x, 1)
            r = operation.ZeroExtend(r, 1)
            offset += 1
        shift = bv2pysmt(build_shift(x, r))
        return sc.BVExtract(shift, end=shift.bv_width() - offset - 1)

    args = [bv2pysmt(a) for a in bv.args]
    op_type = type(bv)

    # operations translated by passing the converted operands unchanged
    # (keyed by exact type, matching the original `type(bv) == ...` checks)
    direct = {
        operation.BvAnd: sc.BVAnd,
        operation.BvOr: sc.BVOr,
        operation.BvXor: sc.BVXor,
        operation.BvComp: sc.Equals,  # Boolean equality, not a 1-bit BVComp
        operation.BvUlt: sc.BVULT,
        operation.BvUle: sc.BVULE,
        operation.BvUgt: sc.BVUGT,
        operation.BvUge: sc.BVUGE,
        operation.Concat: sc.BVConcat,
        operation.ZeroExtend: sc.BVZExt,
        operation.BvNeg: sc.BVNeg,
        operation.BvAdd: sc.BVAdd,
        operation.BvSub: sc.BVSub,
        operation.BvMul: sc.BVMul,
        operation.BvUdiv: sc.BVUDiv,
        operation.BvUrem: sc.BVURem,
    }
    if op_type in direct:
        return direct[op_type](*args)

    if op_type == operation.BvNot:
        # an operand that is already a Boolean equality needs logical Not
        if args[0].is_equals():
            return sc.Not(*args)
        return sc.BVNot(*args)

    if op_type == operation.BvShl:
        if is_power_of_2(args[0].bv_width()):
            return sc.BVLShl(*args)
        return emulate_shift(lambda x, r: x << r)

    if op_type == operation.BvLshr:
        if is_power_of_2(args[0].bv_width()):
            return sc.BVLShr(*args)
        return emulate_shift(lambda x, r: x >> r)

    if op_type == operation.RotateLeft:
        if is_power_of_2(args[0].bv_width()):
            return sc.BVRol(*args)
        x, r = bv.args
        n = x.width
        # low n - r bits move to the top, top r bits move to the bottom
        return bv2pysmt(operation.Concat(x[n - r - 1:], x[n - 1:n - r]))

    if op_type == operation.RotateRight:
        if is_power_of_2(args[0].bv_width()):
            return sc.BVRor(*args)
        x, r = bv.args
        n = x.width
        # low r bits move to the top, top n - r bits move to the bottom
        return bv2pysmt(operation.Concat(x[r - 1:], x[n - 1:r]))

    if op_type == operation.Ite:
        # pySMT's Ite needs a Boolean condition; a 1-bit condition is
        # compared against the constant 1
        if args[0].is_equals():
            condition = args[0]
        else:
            condition = sc.Equals(args[0], bv2pysmt(core.Constant(1, 1)))
        return sc.Ite(condition, *args[1:])

    if op_type == operation.Extract:
        # arxpy Extract(x, i, j) with i >= j -> pySMT BVExtract(x, start=j, end=i)
        return sc.BVExtract(args[0], args[2], args[1])

    if op_type == operation.Repeat:
        return args[0].BVRepeat(args[1])

    raise NotImplementedError(msg)
def bv2pysmt(bv, boolean=False, strict_shift=False, env=None):
    """Convert a bit-vector type to a pySMT type.

    Args:
        bv: the bit-vector `Term` to convert
        boolean: if True, boolean pySMT types (e.g., `pysmt.shortcuts.Bool`)
            are used instead of bit-vector pySMT types (e.g.,
            `pysmt.shortcuts.BV`). Only width-1 terms can be converted to
            booleans. Among operations, only ``BvNot``, ``BvAnd``, ``BvOr``,
            ``BvXor``, ``Ite``, ``BvComp`` and the unsigned comparisons
            support it; other operations raise `ValueError`.
        strict_shift: controls how shifts and rotations are translated.
            If `True`, a shift or rotation whose first operand width is not
            a power of two is emulated rather than translated directly to
            pySMT's shifts and rotations. Shifts are emulated by
            zero-extending to a power-of-two width and extracting the low
            bits. Rotations are emulated with ``Extract``/``Concat``. If
            `False`, all shifts and rotations are translated directly.
        env: a `pysmt.environment.Environment`; if not specified, the global
            pySMT environment is reset (``pysmt.environment.reset_env``) and
            used.

    Raises:
        NotImplementedError: for partial operations and for terms with no
            known translation.
        ValueError: if ``boolean`` is True for an operation that cannot
            return a boolean type.

    ::

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.smt.types import bv2pysmt
        >>> s = bv2pysmt(Constant(0b00000001, 8), boolean=False)
        >>> s, s.get_type()
        (1_8, BV{8})
        >>> x, y = Variable("x", 8), Variable("y", 8)
        >>> s = bv2pysmt(x)
        >>> s, s.get_type()
        (x, BV{8})
        >>> s = bv2pysmt(x + y)
        >>> s, s.get_type()
        ((x + y), BV{8})
        >>> s = bv2pysmt(x <= y)
        >>> s, s.get_type()
        ((x u<= y), Bool)
        >>> s = bv2pysmt(x[4: 2])
        >>> s, s.get_type()
        (x[2:4], BV{3})

    """
    msg = "unknown conversion of {} to a pySMT type".format(type(bv).__name__)

    if env is None:
        env = environment.reset_env()
    fm = env.formula_manager

    if isinstance(bv, int):
        return bv

    pysmt_bv = None

    if isinstance(bv, core.Variable):
        if boolean:
            assert bv.width == 1
            pysmt_bv = fm.Symbol(bv.name, env.type_manager.BOOL())
        else:
            pysmt_bv = fm.Symbol(bv.name, env.type_manager.BVType(bv.width))

    elif isinstance(bv, core.Constant):
        if boolean:
            assert bv.width == 1
            pysmt_bv = fm.Bool(bool(bv))
        else:
            pysmt_bv = fm.BV(bv.val, bv.width)

    elif isinstance(bv, operation.Operation):
        # only 1st layer can return a boolean
        # Equals and Ite work well with BV, the rest don't
        if issubclass(type(bv), extraop.PartialOperation):
            raise NotImplementedError("PartialOperation is not yet supported")

        op_type = type(bv)

        # operations with both a Boolean and a bit-vector pySMT counterpart
        logical_ops = {
            operation.BvNot: (fm.Not, fm.BVNot),
            operation.BvAnd: (fm.And, fm.BVAnd),
            operation.BvOr: (fm.Or, fm.BVOr),
            operation.BvXor: (fm.Xor, fm.BVXor),
        }

        if op_type in logical_ops:
            bool_op, bv_op = logical_ops[op_type]
            if boolean:
                assert bv.width == 1
                args = [bv2pysmt(a, True, strict_shift, env) for a in bv.args]
                pysmt_bv = bool_op(*args)
            else:
                args = [bv2pysmt(a, False, strict_shift, env) for a in bv.args]
                pysmt_bv = bv_op(*args)

        elif op_type == operation.Ite:
            # fm.Ite requires a Boolean type for the condition but
            # bv2pysmt(bv.args[0], True, ...) caused an error, so the
            # condition is converted as a bit-vector and compared with 1
            # (if args[0] is BvComp, it can be further optimized)
            condition = bv2pysmt(bv.args[0], False, strict_shift, env)
            if condition.get_type().is_bv_type():
                condition = fm.Equals(condition, fm.BV(1, 1))
            if boolean:
                assert bv.width == 1
            branches = [
                bv2pysmt(a, bool(boolean), strict_shift, env)
                for a in bv.args[1:]
            ]
            pysmt_bv = fm.Ite(condition, *branches)

        else:
            args = [bv2pysmt(a, False, strict_shift, env) for a in bv.args]

            comparisons = {
                operation.BvUlt: fm.BVULT,
                operation.BvUle: fm.BVULE,
                operation.BvUgt: fm.BVUGT,
                operation.BvUge: fm.BVUGE,
            }
            # operations translated by passing the converted operands as-is
            bv_ops = {
                operation.Concat: fm.BVConcat,
                operation.ZeroExtend: fm.BVZExt,
                # fm.BVRepeat keeps the result in `env`; FNode.BVRepeat would
                # use the global formula manager instead
                operation.Repeat: fm.BVRepeat,
                operation.BvNeg: fm.BVNeg,
                operation.BvAdd: fm.BVAdd,
                operation.BvSub: fm.BVSub,
                operation.BvMul: fm.BVMul,
                operation.BvUdiv: fm.BVUDiv,
                operation.BvUrem: fm.BVURem,
            }

            if op_type == operation.BvComp:
                if boolean:
                    pysmt_bv = fm.Equals(*args)
                else:
                    pysmt_bv = fm.BVComp(*args)
            elif op_type in comparisons:
                pysmt_bv = comparisons[op_type](*args)
            elif boolean:
                raise ValueError("{} cannot return a boolean type".format(
                    type(bv).__name__))
            elif op_type in (operation.BvShl, operation.BvLshr):
                if not strict_shift or _is_power_of_2(args[0].bv_width()):
                    if op_type == operation.BvShl:
                        pysmt_bv = fm.BVLShl(*args)
                    else:
                        pysmt_bv = fm.BVLShr(*args)
                else:
                    # pad both operands with zeros on the MSB side up to a
                    # power-of-two width, shift, then drop the padding
                    x, r = bv.args
                    offset = 0
                    while not _is_power_of_2(x.width):
                        x = operation.ZeroExtend(x, 1)
                        r = operation.ZeroExtend(r, 1)
                        offset += 1
                    shift = bv2pysmt(op_type(x, r), False, strict_shift, env)
                    pysmt_bv = fm.BVExtract(shift, end=shift.bv_width() - offset - 1)
            elif op_type == operation.RotateLeft:
                if not strict_shift or _is_power_of_2(args[0].bv_width()):
                    pysmt_bv = fm.BVRol(*args)
                else:
                    # low n - r bits move to the top, top r bits to the bottom
                    x, r = bv.args
                    n = x.width
                    pysmt_bv = bv2pysmt(
                        operation.Concat(x[n - r - 1:], x[n - 1:n - r]),
                        False, strict_shift, env)
            elif op_type == operation.RotateRight:
                if not strict_shift or _is_power_of_2(args[0].bv_width()):
                    pysmt_bv = fm.BVRor(*args)
                else:
                    # low r bits move to the top, top n - r bits to the bottom
                    x, r = bv.args
                    n = x.width
                    pysmt_bv = bv2pysmt(
                        operation.Concat(x[r - 1:], x[n - 1:r]),
                        False, strict_shift, env)
            elif op_type == operation.Extract:
                # pySMT Extract(bv, start, end)
                pysmt_bv = fm.BVExtract(args[0], args[2], args[1])
            elif op_type in bv_ops:
                pysmt_bv = bv_ops[op_type](*args)
            else:
                # unknown operation: try its evaluated (expanded) form
                bv2 = bv.doit()
                assert bv.width == bv2.width, "{} == {}\n{}\n{}".format(
                    bv.width, bv2.width, bv.vrepr(), bv2.vrepr())
                if bv != bv2:  # avoid cyclic loop
                    pysmt_bv = bv2pysmt(bv2, boolean=boolean,
                                        strict_shift=strict_shift, env=env)
                else:
                    raise NotImplementedError("(doit) " + msg)

    elif isinstance(bv, (difference.Difference, mask.Mask)):
        pysmt_bv = bv2pysmt(bv.val, boolean, strict_shift, env)

    if pysmt_bv is None:
        raise NotImplementedError(msg)

    try:
        pysmt_bv_width = pysmt_bv.bv_width()
    except (AssertionError, TypeError, AttributeError):
        # NOTE(review): pySMT's bv_width() fails on Boolean formulas. With
        # assertions enabled this is an AssertionError/TypeError; with them
        # stripped (-O) the lookup presumably fails with AttributeError.
        pysmt_bv_width = 1  # boolean type
    assert bv.width == pysmt_bv_width
    return pysmt_bv
def _weight(self, output_diff, prefix=None, debug=False, version=2):
    """Compute the weight term of the differential ``input_diff -> output_diff``.

    The result ``int_frac`` is a bit-vector holding a fixed-point value with
    ``k = self._effective_precision`` fractional bits (``k`` in 0..4). The
    integer part is computed with ``PopCountDiff``. Then ``k`` zero bits are
    appended and the fractional corrections (``f1``, ``f12``, ``f3``,
    ``f34``) are subtracted.

    NOTE(review): the operation is presumably modular addition of the
    constant ``self.op.constant``; confirm against the enclosing class.

    Args:
        output_diff: the output difference (its ``val`` is used).
        prefix: prefix of the names of the auxiliary variables. By default
            it is derived from ``hash`` of the differences and the constant.
            NOTE(review): it may differ between interpreter runs if those
            hashes are randomized.
        debug: if True, intermediate values are printed.
        version: 0 (reference), 1 (w/o extra reverse) or 2 (s_000 and no
            HW2 in fr).

    Returns:
        ``int_frac`` if both differences are constants. Otherwise it returns
        ``(int_frac, assertions)``. Here ``assertions`` lists the ``BvComp``
        terms that bind each auxiliary variable introduced for ``Reverse``
        and ``LeadingZeros``.
        In that case ``self._i_auxvar`` is reset and used as a counter.

    Raises:
        ValueError: if the effective precision is not between 0 and 4.
    """
    u = self.input_diff[0].val
    v = output_diff.val
    a = self.op.constant
    n = a.width
    one = core.Constant(1, n)

    assert self._effective_width == n - 1
    assert version in [0, 1, 2]  # 0-reference, 1-w/o extra reverse, 2-s_000 and no HW2 in fr

    if prefix is None:
        prefix = "tmp" + str(abs(hash(u) + hash(v) + hash(a)))

    if isinstance(u, core.Constant) and isinstance(v, core.Constant):
        are_cte_differences = True
    else:
        self._i_auxvar = 0
        assertions = []
        are_cte_differences = False

    # With symbolic differences, Reverse/LeadingZeros are replaced by fresh
    # variables constrained through `assertions`; their names depend on the
    # call order, so the order of rev()/lz() calls below must not change.
    def rev(x):
        if are_cte_differences:
            return extraop.Reverse(x)
        else:
            aux = core.Variable("{}_{}rev".format(prefix, self._i_auxvar), x.width)
            self._i_auxvar += 1
            assertions.append(operation.BvComp(aux, extraop.Reverse(x)))
            return aux

    def lz(x):
        if are_cte_differences:
            return extraop.LeadingZeros(x)
        else:
            aux = core.Variable("{}_{}lz".format(prefix, self._i_auxvar), x.width)
            self._i_auxvar += 1
            assertions.append(operation.BvComp(aux, extraop.LeadingZeros(x)))
            return aux

    def carry(x, y):
        return (x + y) ^ x ^ y

    def rev_carry(x, y):
        return rev(carry(rev(x), rev(y)))

    if version in [0, 1]:
        s00_old = (~(u << one)) & (~(v << one))  # i-bit is True if S_{i} = 00*
    else:
        s00_old = ((~u) & (~v)) << one

    s00_ = s00_old & (~lz(~s00_old))  # if x is 001*...*, then lz(x) = 1100...0

    if version == 0:
        e_i1 = s00_ & (~ (s00_ >> one))  # e_{i-1}
        e_ili = ~s00_ & (s00_ >> one)  # e_{i-l_i}
    else:
        e_i1 = s00_old & (~ (s00_old >> one))  # e_{i-1}
        e_ili = ~s00_old & (s00_old >> one)  # e_{i-l_i}

    q = ~((a << one) ^ (u ^ v))  # q[i] = ~(a[i-1]^u[i]^v[i])
    q = ((q >> one) & e_i1)  # q[i-1, i-3] = (a[i-1]^u[i]^v[i], 0, 0)

    if version == 0:
        s = ((a << one) & e_ili) + (a & (s00_ >> one))
    else:
        s = ((a << one) & e_ili) + (a & (s00_old >> one))

    if version == 0:
        d = rev_carry(s00_, q) | q
    else:
        rev_s00_old = rev(s00_old)
        d = rev(carry(rev_s00_old, rev(q))) | q

    w = (q - (s & d)) | (s & (~d))

    if version == 0:
        w = w << one
        h = rev_carry(s00_ << one, w & (s00_ << one))
    elif version == 1:
        rev_w = rev(w) >> one
        # `>>` binds tighter than `&`: the shift applies to rev(s00_) only
        rev_h = carry((rev_s00_old + one) >> one, rev_w & (rev(s00_) >> one))
    else:
        rev_w = rev(w)
        rev_h = carry(rev_s00_old + one, rev_w & rev_s00_old)

    sbnegb = (u ^ v) << one  # i-bit is True if S_{i} = (b, \neg b, *)

    # integer part of the weight (named int_part to avoid shadowing int)
    if version == 0:
        int_part = extraop.PopCountDiff(sbnegb | s00_, h)  # or hw(sbminb_) + (hw(s00_) - hw(h))
    else:
        int_part = extraop.PopCountDiff(sbnegb | s00_, rev_h)

    def smart_add(x, y):
        # add after zero-extending the narrower operand
        if x.width == y.width:
            return x + y
        elif x.width < y.width:
            return operation.ZeroExtend(x, y.width - x.width) + y
        else:
            return x + operation.ZeroExtend(y, x.width - y.width)

    def smart_sub(x, y):
        # cannot be replaced by smart_add(x, -y)
        if x.width == y.width:
            return x - y
        elif x.width < y.width:
            return operation.ZeroExtend(x, y.width - x.width) - y
        else:
            return x - operation.ZeroExtend(y, x.width - y.width)

    k = self._effective_precision
    if k == 0:
        int_frac = int_part
    elif k == 1:
        int_part = operation.Concat(int_part, core.Constant(0, 1))
        if version == 0:
            f1 = extraop.PopCount(w & h & (~(h >> one)))  # each one adds 2^(-1)
        else:
            f1 = extraop.PopCount(rev_w & rev_h & (~(rev_h << one)))
        int_frac = smart_sub(int_part, f1)
    else:
        two = core.Constant(2, n)
        three = core.Constant(3, n)
        four = core.Constant(4, n)
        if version == 0:
            f12 = extraop.PopCountSum2(
                w & h & (~(h >> one)),
                w & h & ((~(h >> one)) | (~(h >> two)) & (h >> one))
            )  # each one adds 2^(-2), that's why ~(h >> one) need to be counted twice
        elif version == 1:
            f12 = extraop.PopCountSum2(
                rev_w & rev_h & (~(rev_h << one)),
                rev_w & rev_h & ((~(rev_h << one)) | (~(rev_h << two)) & (rev_h << one))
            )
        else:
            f12 = extraop.PopCount(
                # ( ( rev_w & rev_h & (~(rev_h << one)) ) >> one ) |
                (((rev_w & rev_h) >> one) & (~rev_h)) |
                (rev_w & rev_h & ((~(rev_h << one)) | (~(rev_h << two)) & (rev_h << one)))
            )
        if k == 2:
            int_part = operation.Concat(int_part, core.Constant(0, 2))
            int_frac = smart_sub(int_part, f12)
        elif k == 3:
            # f3 cannot be included in f12, since ~(h >> one) would need to be counted 4 times
            if version == 0:
                f3 = extraop.PopCount(w & h & (h >> one) & (h >> two) & (~(h >> three)))
            else:
                f3 = extraop.PopCount(rev_w & rev_h & (rev_h << one) & (rev_h << two) & (~(rev_h << three)))
            int_part = operation.Concat(int_part, core.Constant(0, 3))
            # f12 counts units of 2^(-2): one extra zero bit rescales to 2^(-3)
            f12 = operation.Concat(f12, core.Constant(0, 1))
            int_frac = smart_sub(int_part, smart_add(f12, f3))
        elif k == 4:
            if version == 0:
                f34 = extraop.PopCountSum2(
                    w & h & (h >> one) & (h >> two) & (~(h >> three)),
                    w & h & (h >> one) & (h >> two) & ((~(h >> three)) | (~(h >> four) & (h >> three)))
                )
            elif version == 1:
                f34 = extraop.PopCountSum2(
                    rev_w & rev_h & (rev_h << one) & (rev_h << two) & (~(rev_h << three)),
                    rev_w & rev_h & (rev_h << one) & (rev_h << two) & ((~(rev_h << three)) | (~(rev_h << four) & (rev_h << three)))
                )
            else:
                f34 = extraop.PopCount(
                    # ( (rev_w & rev_h & (rev_h << one) & (rev_h << two) & (~(rev_h << three))) >> one ) |
                    (((rev_w & rev_h) >> one) & rev_h & (rev_h << one) & (~(rev_h << two))) |
                    (rev_w & rev_h & (rev_h << one) & (rev_h << two) & ((~(rev_h << three)) | (~(rev_h << four) & (rev_h << three))))
                )
            int_part = operation.Concat(int_part, core.Constant(0, 4))
            # f12 counts units of 2^(-2): two extra zero bits rescale to 2^(-4)
            f12 = operation.Concat(f12, core.Constant(0, 2))
            int_frac = smart_sub(int_part, smart_add(f12, f34))
        else:
            raise ValueError("precision must be between 0 and 4")

    if debug:
        print("\n\n ~~ ")
        print("u: ", u.bin())
        print("v: ", v.bin())
        print("a: ", a.bin())
        print("s00_: ", s00_.bin())
        print("e_i1: ", e_i1.bin())
        print("e_ili1: ", e_ili.bin())
        print("q: ", q.bin())
        print("s: ", s.bin())
        print("d: ", d.bin())
        print("w: ", w.bin())
        if version == 0:
            print("h: ", h.bin())
        else:
            print("rev_w: ", rev_w.bin())
            print("rev_h: ", rev_h.bin())
        print("sbnegb: ", sbnegb.bin())
        print("int: ", int_part.bin())
        if k == 1:
            print("f1: ", f1.bin())
        elif k > 1:
            print("f12: ", f12.bin())
            if k == 3:
                print("f3: ", f3.bin())
            elif k == 4:
                print("f34: ", f34.bin())
        print("int_frac: ", int_frac.bin())

    if are_cte_differences:
        return int_frac
    else:
        return int_frac, assertions