def update_flag_sub_of(op1, op2):
    """Compute OF in @op1 - @op2.

    Returns a one-element list holding the assignment of `of` to the
    "FLAG_SUB_OF" operator applied to (@op1, @op2).
    """
    overflow = ExprOp("FLAG_SUB_OF", op1, op2)
    return [ExprAff(of, overflow)]
def update_flag_zfaddwc_eq(arg1, arg2, arg3):
    """Return the one-element list assigning `zf` to the
    "FLAG_EQ_ADDWC" operator applied to (@arg1, @arg2, @arg3)."""
    zero_flag = ExprOp("FLAG_EQ_ADDWC", arg1, arg2, arg3)
    return [ExprAff(zf, zero_flag)]
def update_flag_add_cf(op1, op2):
    """Compute cf in @op1 + @op2.

    Returns a one-element list holding the assignment of `cf` to the
    "FLAG_ADD_CF" operator applied to (@op1, @op2).
    """
    carry = ExprOp("FLAG_ADD_CF", op1, op2)
    return [ExprAff(cf, carry)]
def update_flag_add_of(op1, op2):
    """Compute of in @op1 + @op2.

    Returns a one-element list holding the assignment of `of` to the
    "FLAG_ADD_OF" operator applied to (@op1, @op2).
    """
    overflow = ExprOp("FLAG_ADD_OF", op1, op2)
    return [ExprAff(of, overflow)]
def update_flag_subwc_cf(op1, op2, op3):
    """Assign `cf` from the "FLAG_SUBWC_CF" operator on (@op1, @op2, @op3).

    The operator result is XORed with the 1-bit constant 1, i.e. the stored
    flag is the inverse of FLAG_SUBWC_CF.
    NOTE(review): the operator name suggests a subtract-with-carry; the exact
    arithmetic it models is defined outside this block - confirm.
    """
    inverted_carry = ExprOp("FLAG_SUBWC_CF", op1, op2, op3) ^ ExprInt(1, 1)
    return [ExprAff(cf, inverted_carry)]
def update_flag_zf_eq(a, b):
    """Return the one-element list assigning `zf` to the
    "FLAG_EQ_CMP" operator applied to (@a, @b)."""
    zero_flag = ExprOp("FLAG_EQ_CMP", a, b)
    return [ExprAff(zf, zero_flag)]
def simp_cc_conds(expr_simp, expr):
    """Rewrite condition-code operators applied to flag operators into
    explicit comparisons.

    Example: CC_U<(FLAG_SUB_CF(A, B)) => A <u B

    Negated comparisons are expressed as `cmp ? 0 : 1` with 1-bit constants.
    @expr is returned untouched when no rule matches.
    """

    def negate(cond):
        # 1-bit logical NOT of @cond, spelled as cond ? 0 : 1
        return ExprCond(cond, ExprInt(0, 1), ExprInt(1, 1))

    def sign_zero_eq_form():
        # Explicit three-operand shape:
        # (FLAG_SIGN_SUB(X, Y), constant 0, FLAG_EQ_CMP(X, Y))
        return (len(expr.args) == 3 and
                expr.args[0].is_op("FLAG_SIGN_SUB") and
                expr.args[2].is_op("FLAG_EQ_CMP") and
                expr.args[0].args == expr.args[2].args and
                expr.args[1].is_int(0))

    # Unsigned / sign tests built on a subtraction
    if expr.is_op("CC_U>=") and test_cc_eq_args(expr, "FLAG_SUB_CF"):
        return negate(ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args))
    if expr.is_op("CC_U<") and test_cc_eq_args(expr, "FLAG_SUB_CF"):
        return ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args)
    if expr.is_op("CC_NEG") and test_cc_eq_args(expr, "FLAG_SIGN_SUB"):
        return ExprOp(TOK_INF_SIGNED, *expr.args[0].args)
    if expr.is_op("CC_POS") and test_cc_eq_args(expr, "FLAG_SIGN_SUB"):
        return negate(ExprOp(TOK_INF_SIGNED, *expr.args[0].args))

    # Equality against zero
    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ"):
        operand = expr.args[0].args[0]
        return ExprOp(TOK_EQUAL, operand, ExprInt(0, operand.size))
    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ"):
        operand = expr.args[0].args[0]
        return negate(ExprOp(TOK_EQUAL, operand, ExprInt(0, operand.size)))

    # Equality between two operands
    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ_CMP"):
        return negate(ExprOp(TOK_EQUAL, *expr.args[0].args))
    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ_CMP"):
        return ExprOp(TOK_EQUAL, *expr.args[0].args)

    # Bit test (and)
    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ_AND"):
        return ExprOp("&", *expr.args[0].args)
    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ_AND"):
        return negate(ExprOp("&", *expr.args[0].args))

    # Signed orderings built on a subtraction
    if expr.is_op("CC_S>") and test_cc_eq_args(
            expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP"):
        return negate(ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args))
    if expr.is_op("CC_S>") and sign_zero_eq_form():
        return negate(ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args))
    if expr.is_op("CC_S>=") and test_cc_eq_args(
            expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF"):
        return negate(ExprOp(TOK_INF_SIGNED, *expr.args[0].args))
    if expr.is_op("CC_S<") and test_cc_eq_args(
            expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF"):
        return ExprOp(TOK_INF_SIGNED, *expr.args[0].args)
    if expr.is_op("CC_S<=") and test_cc_eq_args(
            expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP"):
        return ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args)
    if expr.is_op("CC_S<=") and sign_zero_eq_form():
        return ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args)

    # Unsigned orderings including equality
    if expr.is_op("CC_U<=") and test_cc_eq_args(
            expr, "FLAG_SUB_CF", "FLAG_EQ_CMP"):
        return ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args)
    if expr.is_op("CC_U>") and test_cc_eq_args(
            expr, "FLAG_SUB_CF", "FLAG_EQ_CMP"):
        return negate(ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args))

    # Signed ordering built on an addition: A + B <s 0  =>  A <s -B
    if expr.is_op("CC_S<") and test_cc_eq_args(
            expr, "FLAG_SIGN_ADD", "FLAG_ADD_OF"):
        left, right = expr.args[0].args
        return ExprOp(TOK_INF_SIGNED, left, -right)

    return expr
def update_flag_zf(a):
    """Return the one-element list assigning `zf` to the
    "FLAG_EQ" operator applied to @a."""
    zero_flag = ExprOp("FLAG_EQ", a)
    return [ExprAff(zf, zero_flag)]
def clz(ir, instr, arg1, arg2):
    """Assign to @arg1 the 'cntleadzeros' operator applied to @arg2.

    Returns a pair: the list of assignments and an empty list.
    """
    leading_zeros = ExprOp('cntleadzeros', arg2)
    return [ExprAssign(arg1, leading_zeros)], []
def update_flag_addwc_of(op1, op2, op3):
    """Compute of in @res = @op1 + @op2 + @op3.

    Returns a one-element list holding the assignment of `of` to the
    "FLAG_ADDWC_OF" operator applied to (@op1, @op2, @op3).
    """
    overflow = ExprOp("FLAG_ADDWC_OF", op1, op2, op3)
    return [ExprAssign(of, overflow)]
def fcvt(ir, instr, arg1, arg2):
    """Assign to @arg1 the conversion of @arg2 through the operator
    'fpconvert_fp<N>', where N is the bit size of @arg1.

    Returns a pair: the list of assignments and an empty list.
    """
    # XXX TODO: rounding
    converted = ExprOp('fpconvert_fp%d' % arg1.size, arg2)
    return [ExprAssign(arg1, converted)], []
def fmsub(ir, instr, arg1, arg2, arg3, arg4):
    """Assign to @arg1 the expression fsub(@arg4, fmul(@arg2, @arg3)).

    Returns a pair: the list of assignments and an empty list.
    """
    product = ExprOp('fmul', arg2, arg3)
    difference = ExprOp('fsub', arg4, product)
    return [ExprAssign(arg1, difference)], []
def fabs(ir, instr, arg1, arg2):
    """Assign to @arg1 the 'fabs' operator applied to @arg2.

    Returns a pair: the list of assignments and an empty list.
    """
    absolute = ExprOp('fabs', arg2)
    return [ExprAssign(arg1, absolute)], []
def fdiv(ir, instr, arg1, arg2, arg3):
    """Assign to @arg1 the 'fdiv' operator applied to (@arg2, @arg3).

    Returns a pair: the list of assignments and an empty list.
    """
    quotient = ExprOp('fdiv', arg2, arg3)
    return [ExprAssign(arg1, quotient)], []
def simp_cst_propagation(e_s, expr):
    """Simplify the operator expression @expr.

    This pass includes:
     - Constant folding
     - Common logical identities
     - Common binary identities

    @e_s: simplifier, called on sub-expressions that must be resolved
          (rotation / shift merging rules)
    @expr: operator expression (its `op` and `args` are read)

    Returns a simplified expression; when no rule returns early, a new
    ExprOp is built from the (possibly reduced) operator and arguments.
    The rules are applied in sequence and later rules rely on the argument
    list left by earlier ones.
    """

    # merge associatif op
    args = list(expr.args)
    op_name = expr.op
    # simpl integer manip
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependent
    if op_name in op_propag_cst:
        # Fold the trailing integer operands pairwise
        while (len(args) >= 2 and
               args[-1].is_int() and
               args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out = int1.arg ** int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg >> int2.arg
            elif op_name == '<<':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                if tmp2 > int1.size:
                    # Shift larger than the size: result is the sign fill
                    is_signed = int(int1) & (1 << (int1.size - 1))
                    if is_signed:
                        out = -1
                    else:
                        out = 0
                else:
                    out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                shifter = int2.arg % int2.size
                out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter))
            elif op_name == '<<<':
                shifter = int2.arg % int2.size
                out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'sdiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'smod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # cnttrailzeros(int) => int
    if op_name == "cnttrailzeros" and args[0].is_int():
        i = 0
        while args[0].arg & (1 << i) == 0 and i < args[0].size:
            i += 1
        return ExprInt(i, args[0].size)

    # cntleadzeros(int) => int
    if op_name == "cntleadzeros" and args[0].is_int():
        if args[0].arg == 0:
            return ExprInt(args[0].size, args[0].size)
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        return ExprInt(expr.size - (i + 1), args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
            len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)
    # A op 0 =>A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()
    # A - 0 =>A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(args) == 2  # Op '-' with more than 2 args: SantityCheckError
        return args[0]

    # A * 1 =>A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # for cannon form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = - args[-1]

    # op A => A
    if op_name in ['+', '*', '^', '&', '|', '>>', '<<',
                   'a>>', '<<<', '>>>', 'sdiv', 'smod', 'umod', 'udiv'] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    if op_name in ['&', "*"] and args[-1].is_int(0):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and
            args[0].is_cond() and
            args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    # Pairwise cancellation / deduplication of operands
    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and
            args[1].is_int() and
            args[1].arg == args[0].size):
        return args[0]

    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
    if (op_name in ['<<<', '>>>'] and
            args[0].is_op() and
            args[0].op in ['<<<', '>>>']):
        A = args[0].args[0]
        X = args[0].args[1]
        Y = args[1]
        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
            # Opposite rotations of the same amount cancel out
            return args[0].args[0]
        elif X.is_int() and Y.is_int():
            new_X = int(X) % expr.size
            new_Y = int(Y) % expr.size
            if op_name == args[0].op:
                rot = (new_X + new_Y) % expr.size
                op = op_name
            else:
                rot = new_Y - new_X
                op = op_name
                if rot < 0:
                    # Negative amount: rotate the other way instead
                    rot = - rot
                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
            args = [A, ExprInt(rot, expr.size)]
            op_name = op

        else:
            # Do not consider this case, too tricky (overflow on addition /
            # subtraction)
            pass

    # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
    # To be sure, only consider the simplification when X.msb and Y.msb are 0
    if (op_name in ['<<', '>>'] and
            args[0].is_op(op_name)):
        X = args[0].args[1]
        Y = args[1]
        if (e_s(X.msb()) == ExprInt(0, 1) and
                e_s(Y.msb()) == ExprInt(0, 1)):
            args = [args[0].args[0], X + Y]

    # ((var >> int1) << int1) => var & mask
    # ((var << int1) >> int1) => var & mask
    # The inner shift must go in the opposite direction: compare operator
    # names (comparing against the expression itself is always "different"
    # and would let a same-direction double shift be turned into a mask).
    if (op_name in ['<<', '>>'] and
            args[0].is_op() and
            args[0].op in ['<<', '>>'] and
            op_name != args[0].op):
        var = args[0].args[0]
        int1 = args[0].args[1]
        int2 = args[1]
        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
            if op_name == '>>':
                # High bits were shifted out: keep the low (size - int1) bits
                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
            else:
                # Low bits were shifted out: clear the low int1 bits
                mask = ExprInt(
                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
                    expr.size
                )
            ret = var & mask
            return ret

    # ((A & A.mask)
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # ((A | A.mask)
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) with mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and
                2 ** args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => (-a) * b * c * d
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            # Odd number of negations: one sign remains
            return -ExprOp(op_name, *new_args)
        args = new_args

    # -(a * b * int) => a * b * (-int)
    if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int():
        args = args[0].args
        return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)]))

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and
            args[1].is_int() and int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index+shift, index+shift+arg.size))
        # filter out expression
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                continue
            if stop > final_size:
                tmp = tmp[:tmp.size - (stop - final_size)]
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index-shift, index+arg.size-shift))
        # filter out expression
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                continue
            if start < 0:
                tmp = tmp[-start:]
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # create entry 0
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)

    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = []
            for i, arg in enumerate(new_args):
                args.append(ExprOp(op_name, *arg))
            return ExprCompose(*args)

    return ExprOp(op_name, *args)
def update_flag_subwc_of(op1, op2, op3):
    """Assign `of` from the "FLAG_SUBWC_OF" operator on (@op1, @op2, @op3).

    Returns a one-element list holding that assignment.
    NOTE(review): the operator name suggests a subtract-with-carry; the exact
    arithmetic it models is defined outside this block - confirm.
    """
    return [ExprAff(of, ExprOp("FLAG_SUBWC_OF", op1, op2, op3))]


def update_flag_arith_subwc_co(arg1, arg2, arg3):
    """Return the `cf` then `of` assignments for (@arg1, @arg2, @arg3), as
    produced by update_flag_subwc_cf and update_flag_subwc_of."""
    e = []
    e += update_flag_subwc_cf(arg1, arg2, arg3)
    e += update_flag_subwc_of(arg1, arg2, arg3)
    return e


# Condition mnemonic -> expression evaluating that condition from the flags.
# Conditions reading the carry are given the inverted carry (cf ^ 1).
cond2expr = {
    'EQ': ExprOp("CC_EQ", zf),
    'NE': ExprOp("CC_NE", zf),
    'CS': ExprOp("CC_U>=", cf ^ ExprInt(1, 1)),  # inv cf
    'CC': ExprOp("CC_U<", cf ^ ExprInt(1, 1)),  # inv cf
    'MI': ExprOp("CC_NEG", nf),
    'PL': ExprOp("CC_POS", nf),
    'VS': ExprOp("CC_sOVR", of),
    'VC': ExprOp("CC_sNOOVR", of),
    'HI': ExprOp("CC_U>", cf ^ ExprInt(1, 1), zf),  # inv cf
    'LS': ExprOp("CC_U<=", cf ^ ExprInt(1, 1), zf),  # inv cf
    'GE': ExprOp("CC_S>=", nf, of),
    'LT': ExprOp("CC_S<", nf, of),
    'GT': ExprOp("CC_S>", nf, of, zf),
    'LE': ExprOp("CC_S<=", nf, of, zf),
    'AL': ExprInt(1, 1),
    'NV': ExprInt(0, 1),
}
def simp_add_multiple(_, expr):
    """
    X + X => 2 * X
    X + X * int1 => X * (1 + int1)
    X * int1 + (- X) => X * (int1 - 1)
    X + (X << int1) => X * (1 + 2 ** int1)
    Correct even if addition overflow/underflow

    @_: unused (simplifier slot)
    @expr: expression to simplify; returned untouched when it is not a '+'
           or when no term could be merged
    """
    if not expr.is_op('+'):
        return expr

    # Extract each argument and its counter
    operands = {}
    for arg in expr.args:
        # Only the binary form `X * int` can be split into (base, factor)
        if arg.is_op('*') and len(arg.args) == 2 and arg.args[1].is_int():
            base_expr, factor = arg.args
            operands[base_expr] = operands.get(base_expr, 0) + int(factor)
        elif arg.is_op('<<') and arg.args[1].is_int():
            base_expr, factor = arg.args
            operands[base_expr] = operands.get(base_expr, 0) + 2 ** int(factor)
        # Only a unary minus negates its operand; any other '-' shape is
        # counted as an opaque term by the final branch
        elif arg.is_op("-") and len(arg.args) == 1:
            arg = arg.args[0]
            if arg.is_op('<<') and arg.args[1].is_int():
                base_expr, factor = arg.args
                operands[base_expr] = operands.get(base_expr, 0) - (2 ** int(factor))
            else:
                operands[arg] = operands.get(arg, 0) - 1
        else:
            operands[arg] = operands.get(arg, 0) + 1
    out = []

    # Best effort to factor common args:
    # (a + b) * 3 + a + b => (a + b) * 4
    # Does not factor:
    # (a + b) * 3 + 2 * a + b => (a + b) * 4 + a
    modified = True
    while modified:
        modified = False
        # Iterate over a snapshot: the dictionary is modified on a match
        for arg in list(operands):
            if not arg.is_op('+'):
                continue
            components = arg.args
            if not all(component in operands for component in components):
                continue
            counters = set(operands[component] for component in components)
            if len(counters) != 1:
                continue
            counter = counters.pop()
            for component in components:
                del operands[component]
            operands[arg] += counter
            modified = True
            break

    # items() is available on both Python 2 and Python 3
    for arg, count in operands.items():
        if count == 0:
            continue
        if count == 1:
            out.append(arg)
            continue
        out.append(arg * ExprInt(count, expr.size))

    if len(out) == len(expr.args):
        # No reductions
        return expr
    if not out:
        return ExprInt(0, expr.size)
    if len(out) == 1:
        return out[0]
    return ExprOp('+', *out)
def update_flag_addwc_cf(op1, op2, op3):
    """Compute cf in @res = @op1 + @op2 + @op3.

    Returns a one-element list holding the assignment of `cf` to the
    "FLAG_ADDWC_CF" operator applied to (@op1, @op2, @op3).
    """
    carry = ExprOp("FLAG_ADDWC_CF", op1, op2, op3)
    return [ExprAff(cf, carry)]
def simp_cc_conds(_, expr):
    """
    High level simplifications: condition-code operators applied to flag
    operators are rewritten as explicit comparisons. Example:
    CC_U<(FLAG_SUB_CF(A, B)) => A <u B

    Negated comparisons are expressed as `cmp ? 0 : 1`, the constants having
    the size of @expr. @expr is returned untouched when no rule matches.
    """

    def negate(cond):
        # Logical NOT of @cond, spelled as cond ? 0 : 1 at @expr's size
        return ExprCond(cond, ExprInt(0, expr.size), ExprInt(1, expr.size))

    def sign_zero_eq_form():
        # Explicit three-operand shape:
        # (FLAG_SIGN_SUB(X, Y), constant 0, FLAG_EQ_CMP(X, Y))
        return (len(expr.args) == 3 and
                expr.args[0].is_op("FLAG_SIGN_SUB") and
                expr.args[2].is_op("FLAG_EQ_CMP") and
                expr.args[0].args == expr.args[2].args and
                expr.args[1].is_int(0))

    # Unsigned / sign tests built on a subtraction
    if expr.is_op("CC_U>=") and test_cc_eq_args(expr, "FLAG_SUB_CF"):
        return negate(ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args))
    if expr.is_op("CC_U<") and test_cc_eq_args(expr, "FLAG_SUB_CF"):
        return ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args)
    if expr.is_op("CC_NEG") and test_cc_eq_args(expr, "FLAG_SIGN_SUB"):
        return ExprOp(TOK_INF_SIGNED, *expr.args[0].args)
    if expr.is_op("CC_POS") and test_cc_eq_args(expr, "FLAG_SIGN_SUB"):
        return negate(ExprOp(TOK_INF_SIGNED, *expr.args[0].args))

    # Equality against zero
    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ"):
        operand = expr.args[0].args[0]
        return ExprOp(TOK_EQUAL, operand, ExprInt(0, operand.size))
    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ"):
        operand = expr.args[0].args[0]
        return negate(ExprOp(TOK_EQUAL, operand, ExprInt(0, operand.size)))

    # Equality between two operands
    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ_CMP"):
        return negate(ExprOp(TOK_EQUAL, *expr.args[0].args))
    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ_CMP"):
        return ExprOp(TOK_EQUAL, *expr.args[0].args)

    # Bit test (and)
    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ_AND"):
        return ExprOp("&", *expr.args[0].args)
    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ_AND"):
        return negate(ExprOp("&", *expr.args[0].args))

    # Signed orderings built on a subtraction
    if expr.is_op("CC_S>") and test_cc_eq_args(
            expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP"):
        return negate(ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args))
    if expr.is_op("CC_S>") and sign_zero_eq_form():
        return negate(ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args))
    if expr.is_op("CC_S>=") and test_cc_eq_args(
            expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF"):
        return negate(ExprOp(TOK_INF_SIGNED, *expr.args[0].args))
    if expr.is_op("CC_S<") and test_cc_eq_args(
            expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF"):
        return ExprOp(TOK_INF_SIGNED, *expr.args[0].args)
    if expr.is_op("CC_S<=") and test_cc_eq_args(
            expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP"):
        return ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args)
    if expr.is_op("CC_S<=") and sign_zero_eq_form():
        return ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args)

    # Unsigned orderings including equality
    if expr.is_op("CC_U<=") and test_cc_eq_args(
            expr, "FLAG_SUB_CF", "FLAG_EQ_CMP"):
        return ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args)
    if expr.is_op("CC_U>") and test_cc_eq_args(
            expr, "FLAG_SUB_CF", "FLAG_EQ_CMP"):
        return negate(ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args))

    # Signed ordering built on an addition: A + B <s 0  =>  A <s -B
    if expr.is_op("CC_S<") and test_cc_eq_args(
            expr, "FLAG_SIGN_ADD", "FLAG_ADD_OF"):
        left, right = expr.args[0].args
        return ExprOp(TOK_INF_SIGNED, left, -right)

    return expr
def update_flag_subwc_of(op1, op2, op3):
    """Assign `of` from the "FLAG_SUBWC_OF" operator on (@op1, @op2, @op3).

    Returns a one-element list holding that assignment.
    NOTE(review): the operator name suggests a subtract-with-carry; the exact
    arithmetic it models is defined outside this block - confirm.
    """
    overflow = ExprOp("FLAG_SUBWC_OF", op1, op2, op3)
    return [ExprAff(of, overflow)]
def simp_cond_flag(_, expr):
    """FLAG_EQ_CMP(X, Y)?A:B => (X == Y)?A:B

    Any other condition leaves @expr untouched.
    """
    flag = expr.cond
    if not flag.is_op("FLAG_EQ_CMP"):
        return expr
    equality = ExprOp(TOK_EQUAL, *flag.args)
    return ExprCond(equality, expr.src1, expr.src2)
def update_flag_nf(arg):
    """Return the one-element list assigning `nf` to the "FLAG_SIGN_SUB"
    operator applied to (@arg, 0), the zero having the size of @arg."""
    zero = ExprInt(0, arg.size)
    sign = ExprOp("FLAG_SIGN_SUB", arg, zero)
    return [ExprAff(nf, sign)]
def simp_slice(e_s, expr):
    """Slice optimization.

    @e_s: simplifier; its `expr_simp_wrapper` is used on sub-slices
    @expr: slice expression (its `arg`, `start` and `stop` are read)

    The rules below are tried in order; the first one that applies returns.
    @expr itself is returned when none applies.
    Raises ValueError when a slice of a slice asks for more bits than the
    inner slice provides.
    """

    # slice(A, 0, a.size) => A
    if expr.start == 0 and expr.stop == expr.arg.size:
        return expr.arg
    # Slice(int) => int
    if expr.arg.is_int():
        total_bit = expr.stop - expr.start
        mask = (1 << (expr.stop - expr.start)) - 1
        return ExprInt(int((expr.arg.arg >> expr.start) & mask), total_bit)
    # Slice(Slice(A, x), y) => Slice(A, z)
    if expr.arg.is_slice():
        if expr.stop - expr.start > expr.arg.stop - expr.arg.start:
            raise ValueError('slice in slice: getting more val', str(expr))

        # Bounds of the outer slice are rebased on the inner slice start
        return ExprSlice(
            expr.arg.arg, expr.start + expr.arg.start,
            expr.start + expr.arg.start + (expr.stop - expr.start))
    if expr.arg.is_compose():
        # Slice(Compose(A), x) => Slice(A, y)
        # (the slice fits entirely inside one component)
        for index, arg in expr.arg.iter_args():
            if index <= expr.start and index + arg.size >= expr.stop:
                return arg[expr.start - index:expr.stop - index]
        # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C
        out = []
        for index, arg in expr.arg.iter_args():
            # arg is before slice start
            if expr.start >= index + arg.size:
                continue
            # arg is after slice stop
            elif expr.stop <= index:
                continue
            # arg is fully included in slice
            elif expr.start <= index and index + arg.size <= expr.stop:
                out.append(arg)
                continue
            # arg is truncated at start
            if expr.start > index:
                slice_start = expr.start - index
            else:
                # arg is not truncated at start
                slice_start = 0
            # a is truncated at stop
            # (a component truncated at both ends was already handled by the
            # single-component loop above, so slice_start is 0 on this path)
            if expr.stop < index + arg.size:
                slice_stop = arg.size + expr.stop - (index + arg.size) - slice_start
            else:
                slice_stop = arg.size
            out.append(arg[slice_start:slice_stop])

        return ExprCompose(*out)

    # ExprMem(x, size)[:A] => ExprMem(x, a)
    # XXXX todo hum, is it safe?
    if (expr.arg.is_mem() and
            expr.start == 0 and
            expr.arg.size > expr.stop and expr.stop % 8 == 0):
        return ExprMem(expr.arg.arg, size=expr.stop)
    # distributivity of slice and &
    # (a & int)[x:y] => 0 if int[x:y] == 0
    if expr.arg.is_op("&") and expr.arg.args[-1].is_int():
        tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop])
        if tmp.is_int(0):
            return tmp
    # distributivity of slice and exprcond
    # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y])
    if expr.arg.is_cond() and expr.arg.src1.is_int() and expr.arg.src2.is_int(
    ):
        src1 = expr.arg.src1[expr.start:expr.stop]
        src2 = expr.arg.src2[expr.start:expr.stop]
        return ExprCond(expr.arg.cond, src1, src2)

    # (a * int)[0:y] => (a[0:y] * int[0:y])
    if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int():
        args = [
            e_s.expr_simp_wrapper(a[expr.start:expr.stop])
            for a in expr.arg.args
        ]
        return ExprOp(expr.arg.op, *args)

    # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
    # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0
    if (expr.arg.is_op() and expr.arg.op in [">>", "<<"] and
            expr.arg.args[1].is_int()):
        arg, shift = expr.arg.args
        shift = int(shift)
        if expr.arg.op == ">>":
            if shift + expr.stop <= arg.size:
                return arg[expr.start + shift:expr.stop + shift]
        elif expr.arg.op == "<<":
            if expr.start - shift >= 0:
                return arg[expr.start - shift:expr.stop - shift]
        else:
            # Not reachable given the operator filter of the enclosing test
            raise ValueError('Bad case')

    return expr
def asr(arg1, arg2, arg3):
    # Binds arg1 to the 'a>>' operator applied to arg2 and to arg3 masked
    # with (arg3.size - 1).
    # NOTE(review): as plain Python this only rebinds a local name and returns
    # None; presumably the body is consumed by a semantic-builder decorator
    # that is not visible here, turning the binding into an assignment -
    # confirm against the decorator at the definition site.
    arg1 = ExprOp('a>>', arg2, (arg3 & ExprInt(arg3.size - 1, arg3.size)))
def simp_cst_propagation(e_s, expr):
    """Simplify the operator expression @expr.

    This pass includes:
     - Constant folding
     - Common logical identities
     - Common binary identities

    @e_s: simplifier slot (not used by this body)
    @expr: operator expression (its `op` and `args` are read)

    Returns a simplified expression; when no rule returns early, a new
    ExprOp is built from the (possibly reduced) operator and arguments.
    The rules are applied in sequence and later rules rely on the argument
    list left by earlier ones.
    """

    # merge associative op
    args = list(expr.args)
    op_name = expr.op
    # simpl integer manip
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependent
    if op_name in op_propag_cst:
        # Fold the trailing integer operands pairwise
        while (len(args) >= 2 and args[-1].is_int() and args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out = int1.arg**int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                out = int1.arg >> int2.arg
            elif op_name == '<<':
                out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                out = (int1.arg >> (int2.arg % int2.size) | int1.arg <<
                       ((int1.size - int2.arg) % int2.size))
            elif op_name == '<<<':
                out = (int1.arg << (int2.arg % int2.size) | int1.arg >>
                       ((int1.size - int2.arg) % int2.size))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'idiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'imod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # bsf(int) => int
    # (index of the lowest set bit; a zero operand is left unresolved)
    if op_name == "bsf" and args[0].is_int() and args[0].arg != 0:
        i = 0
        while args[0].arg & (1 << i) == 0:
            i += 1
        return ExprInt(i, args[0].size)

    # bsr(int) => int
    # (index of the highest set bit; a zero operand is left unresolved)
    if op_name == "bsr" and args[0].is_int() and args[0].arg != 0:
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        return ExprInt(i, args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
            len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)
    # A op 0 =>A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()
    # A - 0 =>A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(
            args) == 2  # Op '-' with more than 2 args: SantityCheckError
        return args[0]

    # A * 1 =>A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # for canonical form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = -args[-1]

    # op A => A
    if op_name in [
            '+', '*', '^', '&', '|', '>>', '<<', 'a>>', '<<<', '>>>', 'idiv',
            'imod', 'umod', 'udiv'
    ] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    # NOTE(review): only the second operand is tested, so a zero sitting in a
    # later position of a 3+ operand '&' / '*' is not caught here - confirm
    # whether the last operand was intended.
    if op_name in ['&', "*"] and args[1].is_int(0):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and args[0].is_cond() and
            args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    # Pairwise cancellation / deduplication of operands
    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and args[1].is_int() and
            args[1].arg == args[0].size):
        return args[0]

    # A <<< X <<< Y => A <<< (X+Y) (or <<< >>>)
    # NOTE(review): the amounts are added / subtracted without any overflow
    # consideration - confirm this is safe for non-constant amounts.
    if (op_name in ['<<<', '>>>'] and args[0].is_op() and
            args[0].op in ['<<<', '>>>']):
        op1 = op_name
        op2 = args[0].op
        if op1 == op2:
            op_name = op1
            args1 = args[0].args[1] + args[1]
        else:
            op_name = op2
            args1 = args[0].args[1] - args[1]

        args0 = args[0].args[0]
        args = [args0, args1]

    # A >> X >> Y  =>  A >> (X+Y)
    # NOTE(review): X + Y is built at the operand size with no guard against
    # wrap-around of the sum - confirm.
    if (op_name in ['<<', '>>'] and args[0].is_op(op_name)):
        args = [args[0].args[0], args[0].args[1] + args[1]]

    # ((A & A.mask)
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # ((A | A.mask)
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) with mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and
                2**args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => (-a) * b * c * d
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            # Odd number of negations: one sign remains
            return -ExprOp(op_name, *new_args)
        args = new_args

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and args[1].is_int() and
            int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index + shift, index + shift + arg.size))
        # filter out expression
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                continue
            if stop > final_size:
                tmp = tmp[:tmp.size - (stop - final_size)]
                stop = final_size
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index - shift, index + arg.size - shift))
        # filter out expression
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                continue
            if start < 0:
                tmp = tmp[-start:]
                start = 0
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # create entry 0
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)

    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            bound = list(bounds)[0]
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = []
            for i, arg in enumerate(new_args):
                args.append(ExprOp(op_name, *arg))
            return ExprCompose(*args)

    # <<<c_rez, >>>c_rez
    if op_name in [">>>c_rez", "<<<c_rez"]:
        assert len(args) == 3
        dest, rounds, carry_flag = args
        # Skipped if rounds is 0
        if rounds.is_int(0):
            return dest
        elif all(arg.is_int() for arg in args):
            # The expression can be resolved
            tmp = int(dest)
            carry_flag = int(carry_flag)
            size = dest.size
            # Count is masked (6 bits for 64-bit operands, 5 bits otherwise)
            # then reduced modulo size + 1 (the carry is part of the rotation)
            tmp_count = (int(rounds) &
                         (0x3f if size == 64 else 0x1f)) % (size + 1)
            if op_name == ">>>c_rez":
                while tmp_count != 0:
                    tmp_cf = tmp & 1
                    tmp = (tmp >> 1) + (carry_flag << (size - 1))
                    carry_flag = tmp_cf
                    tmp_count -= 1
                    tmp &= int(dest.mask)
            elif op_name == "<<<c_rez":
                while tmp_count != 0:
                    tmp_cf = (tmp >> (size - 1)) & 1
                    tmp = (tmp << 1) + carry_flag
                    carry_flag = tmp_cf
                    tmp_count -= 1
                    tmp &= int(dest.mask)
            else:
                # Not reachable given the operator filter of the enclosing test
                raise RuntimeError("Unknown operation: %s" % op_name)
            return ExprInt(tmp, size=dest.size)

    return ExprOp(op_name, *args)
def update_flag_sub_cf(op1, op2):
    """Compute CF in @op1 - @op2

    Return a one-element list holding the assignment of cf.
    """
    # The FLAG_SUB_CF operation is XORed with the 1-bit constant 1, i.e. the
    # value stored in cf is the inverse of the FLAG_SUB_CF result
    sub_carry = ExprOp("FLAG_SUB_CF", op1, op2)
    inverted_carry = sub_carry ^ ExprInt(1, 1)
    return [ExprAff(cf, inverted_carry)]
def apply_expr_on_state_visit_cache(self, expr, state, cache, level=0):
    """
    Depth-first evaluation of @expr:
    1. evaluate the sons of the node
    2. rebuild the node from the evaluated sons and simplify it

    @expr: expression to evaluate
    @state: mapping from expressions to their known value
    @cache: mapping from already visited expressions to their result;
            updated in place
    @level: recursion depth, incremented for each son and forwarded to
            manage_mem
    """

    def visit(son):
        # Evaluate a son one level deeper, sharing the same state and cache
        return self.apply_expr_on_state_visit_cache(
            son, state, cache, level + 1)

    if expr in cache:
        # Already visited: the cached value still goes through the final
        # simplification below
        result = cache[expr]
    elif expr in state:
        # Known by the state: returned as is, neither simplified nor cached
        return state[expr]
    elif expr.is_int():
        result = expr
    elif expr.is_id():
        name = expr.name
        if isinstance(name, asmblock.asm_label) and name.offset is not None:
            # Label with a known offset: replaced by this offset
            result = ExprInt(name.offset, expr.size)
        elif expr in self.regstop:
            result = exprid_top(expr)
        else:
            result = state.get(expr, expr)
    elif expr.is_mem():
        result = self.manage_mem(expr, state, cache, level)
    elif expr.is_cond():
        # The three sons are always evaluated before looking for a TOP
        sons = [visit(expr.cond), visit(expr.src1), visit(expr.src2)]
        if any(son.is_id(TOPSTR) for son in sons):
            result = exprid_top(expr)
        else:
            result = ExprCond(*sons)
    elif expr.is_slice():
        sliced = visit(expr.arg)
        result = (exprid_top(expr) if sliced.is_id(TOPSTR)
                  else ExprSlice(sliced, expr.start, expr.stop))
    elif expr.is_op():
        evaluated = []
        for original in expr.args:
            son = visit(original)
            assert original.size == son.size
            if son.is_id(TOPSTR):
                # First TOP son: stop here, without simplification nor
                # cache update
                return exprid_top(expr)
            evaluated.append(son)
        result = ExprOp(expr.op, *evaluated)
    elif expr.is_compose():
        evaluated = []
        for original in expr.args:
            son = visit(original)
            if son.is_id(TOPSTR):
                # First TOP son: stop here, without simplification nor
                # cache update
                return exprid_top(expr)
            evaluated.append(son)
        result = ExprCompose(*evaluated)
    else:
        raise TypeError("Unknown expr type")

    result = self.expr_simp(result)
    assert expr.size == result.size
    cache[expr] = result
    return result
def udiv(arg1, arg2, arg3):
    # Division of @arg2 by @arg3 through the 'udiv' operation, with the
    # result assigned to @arg1; when @arg3 does not hold, exception_flags
    # is assigned EXCEPT_DIV_BY_ZERO instead.
    # NOTE(review): nothing is returned and exception_flags is read
    # (exception_flags.size) in the statement that assigns it, so this body
    # is presumably rewritten by a decorator that is not visible here
    # (semantic builder style) - confirm before changing the shape of any
    # statement.
    if arg3:
        # @arg3 holds: @arg1 receives 'udiv'(@arg2, @arg3)
        arg1 = ExprOp('udiv', arg2, arg3)
    else:
        # Otherwise: EXCEPT_DIV_BY_ZERO, sized like exception_flags, is
        # assigned to exception_flags
        exception_flags = ExprInt(EXCEPT_DIV_BY_ZERO, exception_flags.size)
def test_ExprOp_toC(self):
    """Check the C translation of ExprOp with one, two and three arguments,
    and that the unknown operator 'X' raises NotImplementedError"""
    from miasm2.expression.expression import ExprInt32, ExprOp
    from miasm2.ir.translators.C import Translator

    args = [ExprInt32(value) for value in xrange(9)]
    translator = Translator.to_language("C")

    def check_translations(cases):
        # Each case is (operator, number of leading @args used, expected C)
        for op_name, arity, expected in cases:
            self.translationTest(ExprOp(op_name, *args[:arity]), expected)

    def check_unsupported(arity):
        # The operator 'X' is expected to be rejected by the translator
        self.assertRaises(NotImplementedError,
                          translator.from_expr,
                          ExprOp('X', *args[:arity]))

    # Unary operators
    check_translations([
        ('parity', 1, r'parity(0x0&0xffffffff)'),
        ('!', 1, r'(~ 0x0)&0xffffffff'),
        ('hex2bcd', 1, r'hex2bcd_32(0x0)'),
        ('fabs', 1, r'fabs(0x0)'),
    ])
    check_unsupported(1)

    # Binary operators ('bsr' is checked here with a single argument)
    check_translations([
        ('==', 2, r'(((0x0&0xffffffff) == (0x1&0xffffffff))?1:0)'),
        ('%', 2, r'(((0x0&0xffffffff)%(0x1&0xffffffff))&0xffffffff)'),
        ('-', 2, r'(((0x0&0xffffffff) - (0x1&0xffffffff))&0xffffffff)'),
        ('bsr', 1, r'x86_bsr(0x0, 0x20)'),
        ('cpuid0', 2, r'cpuid0(0x0, 0x1)'),
        ('fcom0', 2, r'fcom0(0x0, 0x1)'),
        ('fadd', 2, r'fadd(0x0, 0x1)'),
        ('segm', 2, r'segm2addr(jitcpu, 0x0, 0x1)'),
        ('imod', 2, r'imod32((vm_cpu_t*)jitcpu->cpu, 0x0, 0x1)'),
        ('bcdadd', 2, r'bcdadd_32(0x0, 0x1)'),
    ])
    check_unsupported(2)

    # Other cases
    check_translations([
        ('+', 3,
         r'(((0x0&0xffffffff)+(0x1&0xffffffff)+(0x2&0xffffffff))&0xffffffff)'),
    ])
    check_unsupported(3)
def update_flag_zfsubwc_eq(arg1, arg2, arg3):
    """Return a one-element list holding the assignment of zf to the
    FLAG_EQ_SUBWC operation applied to @arg1, @arg2 and @arg3"""
    zero_test = ExprOp("FLAG_EQ_SUBWC", arg1, arg2, arg3)
    zf_assignment = ExprAff(zf, zero_test)
    return [zf_assignment]
def bclrm(rm_deref, imm3): """BCLRM - Bit Clear Memory""" # MemByte(Rm) <- MemByte(Rm) and ~(1<<imm3) shift = ExprOp("<<", i8(1), imm3[:8]) mem8[rm_deref.ptr] = ExprOp("&", mem8[rm_deref.ptr], shift.__invert__())