def getPathConditions(trace, debug=False):
    """Build and solve the path conditions of ``trace`` by walking it backwards.

    The operands in ``trace["final_conditions"]`` are tracked from the last
    instruction to the first.  An instruction that is not a call and is a
    jump, a conditional jump, or writes a tracked location contributes its
    SMT equation.  Otherwise, if the function model built for the
    instruction (``funcs.get(ins.called_function, Function)``) writes a
    tracked location, the function's condition is added instead.  In both
    cases the written operands/locations stop being tracked and the read
    ones start being tracked.  Operands still tracked after the walk are
    reported as free variables.

    Args:
        trace: dict with the keys "code" (instructions, visited last to
            first; the list itself is no longer reordered), "callstack",
            "mem_access", "func_parameters" and "final_conditions".
        debug: print every visited instruction together with the tracked
            operands and locations; also forwarded to ``smt_conds.solve``.

    Returns:
        ``(fvars, Solution(smt_conds.m))`` when the constraints are
        satisfiable, otherwise ``(set(), None)``.

    Side effects: resets ``Memvars``, calls ``setMemoryAccess`` /
    ``clearMemRegs`` on the visited instructions, prints a warning for every
    free operand and writes "exp.smt2" and "exp.sol" in the current
    directory.  ``callstack.index`` is restored before returning.

    NOTE(review): SOURCE defines ``getPathConditions`` several times; if they
    share a module, only the last definition stays bound to the name.
    """
    import sys

    callstack = trace["callstack"]
    memory = trace["mem_access"]
    parameters = trace["func_parameters"]
    final_conditions = trace["final_conditions"]

    # Visit the code last-to-first on a reversed *copy*.  Reversing
    # trace["code"] in place flipped the caller's trace, so a second call on
    # the same trace walked it forwards -- which defeats the callstack
    # save/restore below.  getValueFromCode copies before reversing as well.
    inss = list(trace["code"])
    inss.reverse()

    # we reset the used memory variables
    Memvars.reset()

    # ssa and smt objects
    ssa = SSA()
    smt_conds = SMT()

    def locations_of(operands):
        # Union of the locations covered by every operand in ``operands``.
        return concatSet(map(lambda op: set(op.getLocations()), operands))

    # Start by tracking the operands constrained at the end of the trace.
    mvars = set()
    mlocs = set()
    for op in final_conditions:
        mvars.add(op)
        mlocs = mlocs.union(op.getLocations())

    ssa.getMap(mvars, set(), set())
    setInitialConditions(ssa, final_conditions, smt_conds)

    # callstack.prevInstruction() rewinds the shared callstack during the
    # walk; put it back where it was, also when the walk raises.
    last_index = callstack.index  # TODO: create a better interface
    try:
        for ins in inss:
            counter = ins.getCounter()
            func_cons = funcs.get(ins.called_function, Function)

            access = memory.getAccess(counter)
            if access is not None:
                ins.setMemoryAccess(access)

            ins.clearMemRegs()
            func = func_cons(None, parameters.getParameters(counter))

            if debug:
                out = sys.stdout
                out.write("(%.4d) %s\n" % (counter, ins))
                out.write(" ".join("%s %s --" % (v, v.getSizeInBytes()) for v in mvars) + "\n")
                out.write(" ".join("%s --" % (l,) for l in mlocs) + "\n")

            # Materialise each operand list once: it is needed both as a set
            # and for the locations it covers.
            ins_write_ops = list(ins.getWriteVarOperands())
            ins_read_ops = list(ins.getReadVarOperands())
            func_write_ops = list(func.getWriteVarOperands())
            func_read_ops = list(func.getReadVarOperands())

            ins_write_vars = set(ins_write_ops)
            ins_read_vars = set(ins_read_ops)
            func_write_vars = set(func_write_ops)
            func_read_vars = set(func_read_ops)

            ins_write_locs = locations_of(ins_write_ops)
            ins_read_locs = locations_of(ins_read_ops)
            func_write_locs = locations_of(func_write_ops)
            func_read_locs = locations_of(func_read_ops)

            if (not ins.isCall()) and (ins.isJmp() or ins.isCJmp() or len(ins_write_locs.intersection(mlocs)) > 0):
                ssa_map = ssa.getMap(
                    ins_read_vars.difference(mvars),
                    ins_write_vars,
                    ins_read_vars.intersection(mvars),
                )

                cons = conds.get(ins.instruction, Condition)
                condition = cons(ins, ssa_map)

                # What this instruction writes is now explained; what it
                # reads has to be tracked instead.
                mlocs = ins_read_locs.union(mlocs.difference(ins_write_locs))
                mvars = ins_read_vars.union(mvars.difference(ins_write_vars))

                smt_conds.add(condition.getEq())

            elif len(func_write_locs.intersection(mlocs)) > 0:
                # TODO: clean-up here! ssa.getMap is not called in this branch.
                cons = conds.get(ins.called_function, Condition)
                condition = cons(func, None)

                c = condition.getEq(func_write_locs.intersection(mlocs))

                mlocs = func_read_locs.union(mlocs.difference(func_write_locs))
                mvars = func_read_vars.union(mvars.difference(func_write_vars))

                smt_conds.add(c)

            # we update the current call for next instruction
            callstack.prevInstruction(ins)
    finally:
        callstack.index = last_index  # TODO: create a better interface

    # NOTE(review): the returned map is not used; the call is kept in case
    # getMap also updates the SSA state.
    ssa.getMap(set(), set(), mvars)

    # Operands still tracked after the walk are free.  Input operands are
    # reported as-is, memory operands are renamed via Memvars.read and any
    # other operand gets the "_0" suffix (presumably its initial SSA version).
    fvars = set()
    for var in mvars:
        sys.stdout.write("#Warning__ " + str(var) + " is free!\n")

        if var |iss| InputOp:
            fvars.add(var)
        elif var |iss| MemOp:
            f_op = var.copy()
            f_op.name = Memvars.read(var)
            fvars.add(f_op)
        else:
            f_op = var.copy()
            f_op.name = f_op.name + "_0"
            fvars.add(f_op)

    smt_conds.write_smtlib_file("exp.smt2")
    smt_conds.write_sol_file("exp.sol")
    smt_conds.solve(debug)

    if smt_conds.is_sat():
        return (fvars, Solution(smt_conds.m))
    # unsat :(
    return (set(), None)
def getPathConditions(trace):
    """Build and solve the path conditions of a REIL ``trace``, walking it backwards.

    The operands in ``trace["final_conditions"]`` are tracked from the last
    instruction to the first.  Every ``jcc`` and every instruction that
    writes a tracked operand contributes its SMT equation.  A ``call`` with a
    known ``called_function`` contributes the function model's condition
    when that model writes a tracked operand.  Written operands stop being
    tracked and read operands start being tracked.

    Args:
        trace: dict with the keys "code" (REIL instruction strings, each
            passed to ``parse_reil``; visited last to first, and the list
            itself is no longer reordered), "callstack",
            "initial_conditions", "final_conditions", "mem_access" and
            "func_parameters".

    Returns:
        ``Solution(smt_conds.m, fvars)`` when the constraints are
        satisfiable -- ``fvars`` lists the operands still tracked after the
        walk that have no initial condition -- otherwise None.

    Side effects: resets ``Memvars``, moves ``trace["callstack"]`` via
    ``prevInstruction`` (it is not restored), prints a warning for every free
    operand whose name contains no ":", and, when satisfiable, writes
    "exp.smt2" and "exp.sol" in the current directory.

    NOTE(review): SOURCE defines ``getPathConditions`` several times; if they
    share a module, only the last definition stays bound to the name.
    """
    import sys

    callstack = trace["callstack"]
    initial_values = trace["initial_conditions"]
    final_values = trace["final_conditions"]
    memory = trace["mem_access"]
    parameters = trace["func_parameters"]

    # Visit the code last-to-first on a reversed *copy*: reversing
    # trace["code"] in place flipped the caller's trace, so a second call on
    # the same trace walked it forwards.  getValueFromCode copies as well.
    inss = list(trace["code"])
    inss.reverse()

    # we reset the used memory variables
    Memvars.reset()

    # Index of the instruction being visited; the walk starts at the last one.
    counter = len(inss) - 1

    # ssa and smt objects
    ssa = SSA()
    smt_conds = SMT()

    # Start by tracking every operand that has a final condition.
    mvars = set(final_values.keys())
    ssa.getMap(mvars, set(), set())
    setInitialConditions(ssa, final_values, smt_conds)

    for ins_str in inss:
        pins = parse_reil(ins_str)
        ins = Instruction(pins, memory.getAccess(counter), mem_regs=False)

        ins_write_vars = set(ins.getWriteVarOperands())
        ins_read_vars = set(ins.getReadVarOperands())

        if pins.instruction == "jcc" or not ins_write_vars.isdisjoint(mvars):
            ssa_map = ssa.getMap(
                ins_read_vars.difference(mvars),
                ins_write_vars,
                ins_read_vars.intersection(mvars),
            )
            cons = conds.get(pins.instruction, Condition)
            condition = cons(ins, ssa_map)
            # Written operands are explained by this instruction; what it
            # reads has to be tracked instead.
            mvars = ins_read_vars.union(mvars.difference(ins_write_vars))
            smt_conds.add(condition.getEq())

        elif ins.instruction == "call" and ins.called_function is not None:
            func_cons = funcs.get(ins.called_function, Function)
            func = func_cons(None, parameters.getParameters(counter))
            func_write_vars = set(func.getWriteVarOperands())
            func_read_vars = set(func.getReadVarOperands())

            if not func_write_vars.isdisjoint(mvars):
                # NOTE(review): the returned map is not used (the function
                # condition is built without it); the call is kept in case
                # getMap also updates the SSA state.
                ssa.getMap(
                    func_read_vars.difference(mvars),
                    func_write_vars,
                    func_read_vars.intersection(mvars),
                )
                cons = conds.get(ins.called_function, Condition)
                condition = cons(func, None)
                c = condition.getEq(func_write_vars.intersection(mvars))
                mvars = func_read_vars.union(mvars.difference(func_write_vars))
                smt_conds.add(c)

        # additional conditions
        mvars = addAditionalConditions(mvars, ins, ssa, callstack, smt_conds)

        counter -= 1
        # we update the current call for next instruction
        callstack.prevInstruction(ins_str)

    # Operands still tracked after the walk that no initial condition fixes
    # are free.
    fvars = [v for v in mvars if v not in initial_values]
    for v in fvars:
        if ":" not in v.name:
            sys.stdout.write("#Warning " + str(v) + " is free!\n")

    setInitialConditions(ssa, initial_values, smt_conds)

    if smt_conds.is_sat():
        smt_conds.solve()
        smt_conds.write_smtlib_file("exp.smt2")
        smt_conds.write_sol_file("exp.sol")
        return Solution(smt_conds.m, fvars)
    # unsat :(
    return None
def getPathConditions(trace):
    """Build and solve the path conditions of ``trace``, walking it backwards.

    The operands in ``trace["final_conditions"]`` are tracked from the last
    instruction to the first.  Every ``jcc`` and every instruction that
    writes a tracked operand contributes its SMT equation.  A call with a
    known ``called_function`` contributes the function model's condition
    when that model writes a tracked operand.  Written operands stop being
    tracked and read operands start being tracked.

    Args:
        trace: dict with the keys "code" (already-built instruction objects,
            visited last to first; the list itself is no longer reordered),
            "callstack", "initial_conditions", "final_conditions",
            "mem_access" and "func_parameters".

    Returns:
        ``Solution(smt_conds.m, fvars)`` when the constraints are
        satisfiable -- ``fvars`` lists the operands still tracked after the
        walk that have no initial condition -- otherwise None.

    Side effects: resets ``Memvars``, calls ``fixMemoryAccess`` on visited
    instructions that have a recorded memory access, moves
    ``trace["callstack"]`` via ``prevInstruction`` (it is not restored),
    prints a warning for every free operand whose name contains no ":", and,
    when satisfiable, writes "exp.smt2" and "exp.sol" in the current
    directory.
    """
    import sys

    callstack = trace["callstack"]
    initial_values = trace["initial_conditions"]
    final_values = trace["final_conditions"]
    memory = trace["mem_access"]
    parameters = trace["func_parameters"]

    # Visit the code last-to-first on a reversed *copy*: reversing
    # trace["code"] in place flipped the caller's trace, so a second call on
    # the same trace walked it forwards.  getValueFromCode copies as well.
    inss = list(trace["code"])
    inss.reverse()

    # we reset the used memory variables
    Memvars.reset()

    # Index of the instruction being visited; the walk starts at the last one.
    counter = len(inss) - 1

    # ssa and smt objects
    ssa = SSA()
    smt_conds = SMT()

    # Start by tracking every operand that has a final condition.
    mvars = set(final_values.keys())
    ssa.getMap(mvars, set(), set())
    setInitialConditions(ssa, final_values, smt_conds)

    for ins in inss:
        access = memory.getAccess(counter)
        if access is not None:
            ins.fixMemoryAccess(access)

        ins_write_vars = set(ins.getWriteVarOperands())
        ins_read_vars = set(ins.getReadVarOperands())

        if ins.instruction == "jcc" or not ins_write_vars.isdisjoint(mvars):
            ssa_map = ssa.getMap(
                ins_read_vars.difference(mvars),
                ins_write_vars,
                ins_read_vars.intersection(mvars),
            )
            cons = conds.get(ins.instruction, Condition)
            condition = cons(ins, ssa_map)
            # Written operands are explained by this instruction; what it
            # reads has to be tracked instead.
            mvars = ins_read_vars.union(mvars.difference(ins_write_vars))
            smt_conds.add(condition.getEq())

        elif ins.isCall() and ins.called_function is not None:
            func_cons = funcs.get(ins.called_function, Function)
            func = func_cons(None, parameters.getParameters(counter))
            func_write_vars = set(func.getWriteVarOperands())
            func_read_vars = set(func.getReadVarOperands())

            if not func_write_vars.isdisjoint(mvars):
                # NOTE(review): the returned map is not used (the function
                # condition is built without it); the call is kept in case
                # getMap also updates the SSA state.
                ssa.getMap(
                    func_read_vars.difference(mvars),
                    func_write_vars,
                    func_read_vars.intersection(mvars),
                )
                cons = conds.get(ins.called_function, Condition)
                condition = cons(func, None)
                c = condition.getEq(func_write_vars.intersection(mvars))
                mvars = func_read_vars.union(mvars.difference(func_write_vars))
                smt_conds.add(c)

        # additional conditions
        mvars = addAditionalConditions(mvars, ins, ssa, callstack, smt_conds)

        counter -= 1
        # we update the current call for next instruction
        callstack.prevInstruction(ins)

    # Operands still tracked after the walk that no initial condition fixes
    # are free.
    fvars = [v for v in mvars if v not in initial_values]
    for v in fvars:
        if ":" not in v.name:
            sys.stdout.write("#Warning " + str(v) + " is free!\n")

    setInitialConditions(ssa, initial_values, smt_conds)

    if smt_conds.is_sat():
        smt_conds.solve()
        smt_conds.write_smtlib_file("exp.smt2")
        smt_conds.write_sol_file("exp.sol")
        return Solution(smt_conds.m, fvars)
    # unsat :(
    return None
def getValueFromCode(reil_code, callstack, memory, addr_op, addr, val_op, val):
    """Ask the solver whether ``addr_op``/``val_op`` can end up holding ``addr``/``val``.

    The equalities ``addr_op == addr`` and ``val_op == val`` are asserted on
    the SSA names obtained for the two operands before the backward walk
    starts (presumably their values after the last instruction -- confirm).
    The code is then walked backwards, adding the condition of every
    instruction that is a ``jcc`` or writes a tracked operand.  The outcome
    is printed, not returned.

    Args:
        reil_code: non-empty sequence of REIL instruction strings (each is
            passed to ``parse_reil``); it is copied, not modified.
        callstack: call-stack tracker; its ``index``, ``prevInstruction`` and
            ``convertStackMemOp`` are used, and it is moved by
            ``prevInstruction`` during the walk (not restored afterwards).
        memory: memory-access record queried with ``getAccess(counter)``.
        addr_op, val_op: operands whose values are constrained.
        addr: value for ``addr_op``; wrapped as ``Operand(str(addr), "DWORD")``.
        val: value for ``val_op``; wrapped as ``Operand(str(val), "BYTE")``
            whose size is then set to ``val_op.size``.

    Returns:
        None.  Prints "Solution:" followed by values of the free variables
        when the constraints are satisfiable, otherwise "Not exploitable".
        "exp.smt2" and "exp.sol" are written in the current directory in
        both cases.

    NOTE(review): the block structure below was reconstructed from a listing
    whose indentation was lost; in particular, confirm that the free-variable
    check and the stack-memory rewrite belong inside the ``call`` branch.
    """
    assert(reil_code <> [])

    # NOTE(review): only reassigned at the ``counter == 0`` break below, so
    # unless that break is reached, "Solution:" is printed with no values.
    free_variables = []

    # code should be copied and reversed
    inss = list(reil_code)
    inss.reverse()

    # counter is set
    # NOTE(review): this starts at len(reil_code), whereas getPathConditions
    # starts its counter at len(inss) - 1; confirm which index
    # memory.getAccess expects.
    counter = len(reil_code)

    # NOTE(review): assigned here and inside the loop but never read.
    tracked_stack_frame = callstack.index

    # special operands in a call

    ssa = SSA()
    smt_conds = SMT()

    # we will track op
    mvars = set([addr_op, val_op])
    ssa_map = ssa.getMap(mvars, set(), set())
    eq = Eq(None, None)

    # Wrap the requested values as constant operands.
    addr = Operand(str(addr), "DWORD")
    val = Operand(str(val), "BYTE")
    val.size = val_op.size

    smt_conds.add(eq.getEq(ssa_map[addr_op.name],addr))
    smt_conds.add(eq.getEq(ssa_map[val_op.name],val))

    for ins_str in inss:
        #print ins_str.strip("\n")
        pins = parse_reil(ins_str)
        ins = Instruction(pins, memory.getAccess(counter), mem_regs = False)

        ins_write_vars = set(ins.getWriteVarOperands())
        ins_read_vars = set(ins.getReadVarOperands())

        # Conditional jumps always contribute a condition; other instructions
        # only when they write an operand that is still tracked.
        if pins.instruction == "jcc" or len(ins_write_vars.intersection(mvars)) > 0:
            #if len(ins_write_vars.intersection(mvars)) > 0:

            ssa_map = ssa.getMap(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

            cons = conds.get(pins.instruction, Condition)
            condition = cons(ins, ssa_map)

            # Written operands are explained by this instruction; the
            # operands it reads have to be tracked instead.
            mvars = mvars.difference(ins_write_vars)
            mvars = ins_read_vars.union(mvars)

            smt_conds.add(condition.getEq())

        counter = counter - 1

        if len(mvars) > 0:
            tracked_stack_frame = callstack.index

        if pins.instruction == "call":

            # NOTE(review): presumably the expected esp value at a call in the
            # outermost frame (index 1) vs. nested frames -- confirm.
            if callstack.index == 1:
                esp_val = 4
            else:
                esp_val = 8

            ebp_val = 0

            esp_op = Operand("esp","DWORD")
            ebp_op = Operand("ebp","DWORD")

            initial_values_at_call = dict()
            initial_values_at_call[esp_op] = Operand(str(esp_val), "DWORD")
            initial_values_at_call[ebp_op] = Operand(str(ebp_val), "DWORD")

            # Keep only the registers that are still tracked.  Deleting while
            # looping is safe only because Python 2's keys() returns a list
            # copy; over a Python 3 dict view this raises RuntimeError.
            for iop in initial_values_at_call.keys():
                if not (iop in mvars):
                    del initial_values_at_call[iop]

            ssa_map = ssa.getMap(set(), set(), set(initial_values_at_call.keys()))
            eq = Eq(None, None)

            for iop in initial_values_at_call:
                smt_conds.add(eq.getEq(ssa_map[iop.name],initial_values_at_call[iop]))

            # Those registers now have fixed values, so they are no longer tracked.
            mvars = set(filter(lambda o: not (o in initial_values_at_call.keys()), mvars))

            # counter reaches 0 after the first instruction of reil_code has
            # been visited; whatever is still tracked then is recorded as free.
            if (counter == 0 and len(mvars)>0):
                #cond = Initial_Cond(None, None)
                #
                #for v in mvars:
                #  print str(v),
                #  smt_conds.add(cond.getEq(v))
                #print "are free"
                #print smt_conds.solver
                free_variables = mvars
                break

            # Rewrite tracked memory operands through the call stack and
            # equate each original operand with its rewritten form.
            new_mvars = set()
            for v in mvars:
                if v.isMem(): # this should work for stack memory
                    eop = callstack.convertStackMemOp(v)
                    #print eop
                    smt_conds.add(eq.getEq(v,eop))
                    new_mvars.add(eop)

            mvars = set(filter(lambda o: not (o.isMem()), mvars))
            mvars = mvars.union(new_mvars)

        # we update the current call for next instruction
        callstack.prevInstruction(ins_str)

    #op.name = op.name+"_0"
    smt_conds.solve()
    smt_conds.write_smtlib_file("exp.smt2")
    smt_conds.write_sol_file("exp.sol")

    if (smt_conds.is_sat()):
        print "Solution:",
        for v in free_variables:
            if v.isReg():
                # NOTE(review): ssa_map is indexed by operand *name* above
                # (ssa_map[addr_op.name]) but by the operand object here, and
                # it holds whatever the last getMap call returned -- confirm.
                if (v in ssa_map):
                    print v,smt_conds.getValue(ssa_map[v])
            elif v.isMem():
                # NOTE(review): `stack` is not defined in this function; it
                # must resolve to a module-level name not visible here.
                sname, offset = stack.read(v)
                v.mem_source = sname
                print v, smt_conds.getValue(v)
    else:
        print "Not exploitable"