Ejemplo n.º 1
0
    def __init__(self, reil_code):
        """Build the callstack layout of a REIL trace.

        The trace is split at every ``ret`` and at every ``call`` whose
        ``called_function`` is None. At each split point, the instructions
        from the previous split point (inclusive) up to this one (exclusive)
        are handed to ``__getStackDiff__`` together with the split
        instruction's opcode and address. If the last split point is not the
        final instruction, the remaining tail is handed over as well, using
        the opcode/address of the instruction the tail starts with.
        ``self.index`` finally points at the last callstack entry.

        reil_code: indexable, sliceable sequence of raw REIL instruction
            strings; stored as ``self.reil_code``.
        """
        # Entry 0 stands for the frame active before the first instruction,
        # which is expected to be a call.
        self.callstack = [None]
        self.stack_diff = []
        self.index = 0

        # Auxiliary state that __getStackDiff__ uses to rebuild the callstack.
        self.calls = [None]
        self.esp_diffs = [None]
        self.reil_code = reil_code

        total = len(reil_code)
        seg_start = 0

        for pos, raw in enumerate(self.reil_code):
            cur = Instruction(parse_reil(raw), None)
            opcode = cur.instruction
            is_split = ((opcode == "call" and cur.called_function is None)
                        or opcode == "ret")
            if is_split:
                self.__getStackDiff__(opcode, cur.address,
                                      reil_code[seg_start:pos])
                seg_start = pos

        # The segment after the last split point has not been recorded yet.
        if seg_start != total - 1:
            head = parse_reil(reil_code[seg_start])
            self.__getStackDiff__(head.instruction, head.address,
                                  reil_code[seg_start:total])

        self.index = len(self.callstack) - 1
Ejemplo n.º 2
0
 def __init__(self, reil_code):
   """Build the callstack layout of a REIL trace.

   The trace is split at every ``ret`` and at every ``call`` whose
   ``called_function`` is None. At each split point, the instructions from
   the previous split point (inclusive) up to this one (exclusive) are handed
   to ``__getStackDiff__`` together with the split instruction's opcode and
   address. If the last split point is not the final instruction, the
   remaining tail is handed over as well, using the opcode/address of the
   instruction the tail starts with. ``self.index`` finally points at the
   last callstack entry.

   reil_code: indexable, sliceable sequence of raw REIL instruction strings;
     stored as ``self.reil_code``.
   """
   # Entry 0 stands for the frame active before the first instruction,
   # which is expected to be a call.
   self.callstack = [None]
   self.stack_diff = []
   self.index = 0

   # Auxiliary state that __getStackDiff__ uses to rebuild the callstack.
   self.calls = [None]
   self.esp_diffs = [None]
   self.reil_code = reil_code

   total = len(reil_code)
   seg_start = 0

   for pos, raw in enumerate(self.reil_code):
     cur = Instruction(parse_reil(raw), None)
     opcode = cur.instruction
     is_split = ((opcode == "call" and cur.called_function is None)
                 or opcode == "ret")
     if is_split:
       self.__getStackDiff__(opcode, cur.address, reil_code[seg_start:pos])
       seg_start = pos

   # The segment after the last split point has not been recorded yet.
   if seg_start != total - 1:
     head = parse_reil(reil_code[seg_start])
     self.__getStackDiff__(head.instruction, head.address,
                           reil_code[seg_start:total])

   self.index = len(self.callstack) - 1
Ejemplo n.º 3
0
    def detectFuncParameters(self, reil_code, memaccess, callstack, inputs,
                             counter):
        """Resolve the arguments of the resolved call that ends ``reil_code``.

        The last instruction must be a ``call`` with a non-None
        ``called_function``. The value of ``esp`` at that point is used to
        instantiate the function model ``funcs[called_function]`` (falling
        back to ``Function``). Every parameter location the model marks as
        needed is evaluated backwards over ``reil_code``. Locations that are
        not needed are recorded as ``(None, None)``. When the model reports
        any parameter locations, the ``__getParameters__`` result is stored in
        ``self.parameters[counter]``.
        """
        call_ins = Instruction(parse_reil(reil_code[-1]), None)
        assert (call_ins.instruction == "call"
                and call_ins.called_function is not None)

        # The stack pointer at the call site tells where the arguments live.
        stack_ptr = getTypedValueFromCode(reil_code, callstack, inputs,
                                          memaccess, Operand("esp", "DWORD"))

        model_cls = funcs.get(call_ins.called_function, Function)
        model = model_cls(pbase=stack_ptr)

        collected = []
        for _, where, wanted in model.getParameterLocations():
            if not wanted:
                collected.append((None, None))
                continue
            # getTypedValueFromCode calls reverse() on the code it walks, so
            # reverse and rewind the trace before each further lookup
            # (presumably restoring forward order -- confirm with REIL_Trace).
            reil_code.reverse()
            reil_code.reset()
            collected.append((where,
                              getTypedValueFromCode(reil_code, callstack,
                                                    inputs, memaccess,
                                                    where)))

        if collected:
            self.parameters[counter] = self.__getParameters__(call_ins,
                                                              collected)
Ejemplo n.º 4
0
    def detectMemAccess(self, reil_code, callstack, inputs, counter):
        """Record the memory access made by the last instruction of ``reil_code``.

        The last instruction must be an ``stm`` or ``ldm``. Its address
        operand is evaluated backwards over ``reil_code`` with
        getTypedValueFromCode, passing ``self`` as the memory-access source.
        A memory-typed result is stored via ``__getMemAccess__``. An immediate
        result is treated as a global address and stored via
        ``__getGlobalMemAccess__``. Either way the entry goes into
        ``self.access[counter]``.

        Raises:
            AssertionError: if the address evaluates to something that is
                neither memory nor an immediate (and, when assertions are
                enabled, if the instruction is not stm/ldm).
        """
        pins = parse_reil(reil_code[-1])
        ins = Instruction(pins, None)

        assert (ins.instruction in ["stm", "ldm"])
        addr_op = ins.getMemReg()
        val = getTypedValueFromCode(reil_code, callstack, inputs, self,
                                    addr_op)

        if val.isMem():
            self.access[counter] = self.__getMemAccess__(ins, val)
        # isImm must be called: a bare bound method is always truthy and
        # would route every non-memory value into this branch.
        elif val.isImm():
            self.access[counter] = self.__getGlobalMemAccess__(
                ins, int(val.name))
        else:
            raise AssertionError("unsupported address value %s at %s"
                                 % (val, ins.address))
Ejemplo n.º 5
0
 def detectFuncParameters(self, reil_code, memaccess, callstack, inputs, counter):
   """Resolve the arguments of the resolved call that ends ``reil_code``.

   The last instruction must be a ``call`` with a non-None
   ``called_function``. The value of ``esp`` at that point is used to
   instantiate the function model ``funcs[called_function]`` (falling back
   to ``Function``). Every parameter location the model marks as needed is
   evaluated backwards over ``reil_code``. Locations that are not needed are
   recorded as ``(None, None)``. When the model reports any parameter
   locations, the ``__getParameters__`` result is stored in
   ``self.parameters[counter]``.
   """
   call_ins = Instruction(parse_reil(reil_code[-1]), None)
   assert (call_ins.instruction == "call"
           and call_ins.called_function is not None)

   # The stack pointer at the call site tells where the arguments live.
   stack_ptr = getTypedValueFromCode(reil_code, callstack, inputs, memaccess,
                                     Operand("esp", "DWORD"))

   model_cls = funcs.get(call_ins.called_function, Function)
   model = model_cls(pbase=stack_ptr)

   collected = []
   for _, where, wanted in model.getParameterLocations():
     if not wanted:
       collected.append((None, None))
       continue
     # getTypedValueFromCode calls reverse() on the code it walks, so reverse
     # and rewind the trace before each further lookup (presumably restoring
     # forward order -- confirm with REIL_Trace).
     reil_code.reverse()
     reil_code.reset()
     collected.append((where,
                       getTypedValueFromCode(reil_code, callstack, inputs,
                                             memaccess, where)))

   if collected:
     self.parameters[counter] = self.__getParameters__(call_ins, collected)
Ejemplo n.º 6
0
 def __getStackDiff__(self, inst, addr, reil_code):
   """Record the stack effect of one call/ret-delimited code segment.

   inst: delimiter opcode, "call" or "ret"; anything else fails the assert.
   addr: hex address string of the delimiter (only parsed for "call").
   reil_code: raw REIL instructions of the segment.
   """
   if inst == "call":
     target = int(addr, 16)
     diff = self.__getESPdifference__(reil_code, 0)
     # Same append order as before: calls, callstack, stack_diff, esp_diffs.
     for frames in (self.calls, self.callstack):
       frames.append(target)
     for diffs in (self.stack_diff, self.esp_diffs):
       diffs.append(diff)
   elif inst == "ret":
     opens_with_call = parse_reil(reil_code[0]).instruction == "call"
     if opens_with_call:
       self.stack_diff.append(self.__getESPdifference__(reil_code, 0))
     else:
       # Drop the returning frame and fall back to its caller's state.
       self.calls.pop()
       self.esp_diffs.pop()
       caller = self.calls[-1]
       caller_diff = self.esp_diffs[-1]
       self.stack_diff.append(self.__getESPdifference__(reil_code, caller_diff))
       self.callstack.append(caller)
   else:
     assert False
Ejemplo n.º 7
0
def getJumpConditions(trace, addr):
    """Look for inputs that make the trace's final ``jcc`` jump to ``addr``.

    trace: trace dict as built by mkTrace. Its "final_conditions" entry is
        overwritten with the jump-target constraint.
    addr: target address as a hex string.

    Prints whether the path conditions are satisfiable and which files the
    solution dumped. Always returns None.
    """
    code = trace["code"]
    last = parse_reil(code[-1])
    target = int(addr, 16)
    position = code.last - 1

    if last.instruction != "jcc":
        return None
    jmp_op = Instruction(last, None).operands[2]
    if not jmp_op.isVar():
        return None

    trace["final_conditions"] = {jmp_op: Operand(str(target), "DWORD")}
    sol = getPathConditions(trace)

    if sol is None:
        print(" ".join(str(part) for part in (
            "Impossible to jump to", hex(target),
            "from", last.instruction, "at", position)))
        return None

    print("SAT conditions found!")
    dump_name = last.instruction + "[" + str(position) + "]"
    # NOTE(review): input_vars is not defined in this function; it must be a
    # module-level name.
    for dumped_file in sol.dump(dump_name, input_vars):
        print(" ".join(str(part) for part in (dumped_file, "dumped!")))
    return None
Ejemplo n.º 8
0
def getJumpConditions(trace, addr):
    """Look for inputs that make the trace's final ``jcc`` jump to ``addr``.

    trace: trace dict as built by mkTrace. Its "final_conditions" entry is
        overwritten with the jump-target constraint.
    addr: target address as a hex string.

    Prints whether the path conditions are satisfiable and which files the
    solution dumped. Always returns None.
    """
    code = trace["code"]
    last = parse_reil(code[-1])
    target = int(addr, 16)
    position = code.last - 1

    if last.instruction != "jcc":
        return None
    jmp_op = Instruction(last, None).operands[2]
    if not jmp_op.isVar():
        return None

    trace["final_conditions"] = {jmp_op: Operand(str(target), "DWORD")}
    sol = getPathConditions(trace)

    if sol is None:
        print(" ".join(str(part) for part in (
            "Impossible to jump to", hex(target),
            "from", last.instruction, "at", position)))
        return None

    print("SAT conditions found!")
    dump_name = last.instruction + "[" + str(position) + "]"
    # NOTE(review): input_vars is not defined in this function; it must be a
    # module-level name.
    for dumped_file in sol.dump(dump_name, input_vars):
        print(" ".join(str(part) for part in (dumped_file, "dumped!")))
    return None
Ejemplo n.º 9
0
    def __getStackDiff__(self, inst, addr, reil_code):
        """Record the stack effect of one call/ret-delimited code segment.

        inst: delimiter opcode, "call" or "ret"; anything else fails the
            assert.
        addr: hex address string of the delimiter (only parsed for "call").
        reil_code: raw REIL instructions of the segment.
        """
        if inst == "call":
            target = int(addr, 16)
            diff = self.__getESPdifference__(reil_code, 0)
            # Same append order as before: calls, callstack, stack_diff,
            # esp_diffs.
            for frames in (self.calls, self.callstack):
                frames.append(target)
            for diffs in (self.stack_diff, self.esp_diffs):
                diffs.append(diff)
        elif inst == "ret":
            opens_with_call = parse_reil(reil_code[0]).instruction == "call"
            if opens_with_call:
                self.stack_diff.append(self.__getESPdifference__(reil_code, 0))
            else:
                # Drop the returning frame and fall back to its caller's state.
                self.calls.pop()
                self.esp_diffs.pop()
                caller = self.calls[-1]
                caller_diff = self.esp_diffs[-1]
                self.stack_diff.append(
                    self.__getESPdifference__(reil_code, caller_diff))
                self.callstack.append(caller)
        else:
            assert False
Ejemplo n.º 10
0
def getValueFromCode(inss, initial_values, op):
    """Solve for the value ``op`` holds at the end of a REIL instruction run.

    The instructions are walked from last to first. Every instruction that
    writes a tracked operand contributes an SMT condition, and its read
    operands become tracked in turn (``ebp`` is never tracked). Entries of
    ``initial_values`` whose operand is known to the SSA state are asserted
    as equalities. The problem is then solved and the model value of ``op``
    is returned.

    Args:
        inss: non-empty sequence of raw REIL instruction strings. A reversed
            copy is walked, so the caller's sequence is left untouched.
        initial_values: mapping from operands to the values they must equal.
            It is not modified.
        op: operand to solve for. NOTE: ``op.name`` gets ``"_0"`` appended in
            place, since that is the SSA name queried from the solver.

    Returns:
        Whatever ``SMT.getValue`` returns for ``op``.
    """
    assert (len(inss) > 0)

    # Walk a reversed copy so the caller's code order is not changed.
    code = list(inss)
    code.reverse()

    ssa = SSA()
    smt_conds = SMT()

    # we will track op
    mvars = set([op])
    ssa.getMap(mvars, set(), set())

    for ins_str in code:
        pins = parse_reil(ins_str)
        ins = Instruction(pins, None)  # no memory and callstack are available

        ins_write_vars = set(ins.getWriteVarOperands())
        ins_read_vars = set(ins.getReadVarOperands())

        if ins_write_vars.intersection(mvars):
            ssa_map = ssa.getMap(ins_read_vars.difference(mvars),
                                 ins_write_vars,
                                 ins_read_vars.intersection(mvars))

            cons = conds.get(pins.instruction, Condition)
            condition = cons(ins, ssa_map)

            mvars = mvars.difference(ins_write_vars)
            mvars = ins_read_vars.union(mvars)
            mvars = set(o for o in mvars if o.name != "ebp")

            smt_conds.add(condition.getEq())

    # Only constrain operands the SSA state knows about. Build a new dict
    # rather than deleting entries from the caller's mapping.
    known_values = dict((iop, value) for iop, value in initial_values.items()
                        if iop in ssa)

    ssa_map = ssa.getMap(set(), set(), set(known_values.keys()))
    eq = Eq(None, None)

    for iop in known_values:
        smt_conds.add(eq.getEq(ssa_map[iop.name], known_values[iop]))

    op.name = op.name + "_0"
    smt_conds.solve()

    return smt_conds.getValue(op)
Ejemplo n.º 11
0
def getValueFromCode(inss, initial_values, op):
  """Solve for the value ``op`` holds at the end of a REIL instruction run.

  The instructions are walked from last to first. Every instruction that
  writes a tracked operand contributes an SMT condition, and its read
  operands become tracked in turn (``ebp`` is never tracked). Entries of
  ``initial_values`` whose operand is known to the SSA state are asserted as
  equalities. The problem is then solved and the model value of ``op`` is
  returned.

  Args:
    inss: non-empty sequence of raw REIL instruction strings. A reversed copy
      is walked, so the caller's sequence is left untouched.
    initial_values: mapping from operands to the values they must equal. It
      is not modified.
    op: operand to solve for. NOTE: ``op.name`` gets ``"_0"`` appended in
      place, since that is the SSA name queried from the solver.

  Returns:
    Whatever ``SMT.getValue`` returns for ``op``.
  """
  assert(len(inss) > 0)

  # Walk a reversed copy so the caller's code order is not changed.
  code = list(inss)
  code.reverse()

  ssa = SSA()
  smt_conds = SMT()

  # we will track op
  mvars = set([op])
  ssa.getMap(mvars, set(), set())

  for ins_str in code:
    pins = parse_reil(ins_str)
    ins = Instruction(pins, None) # no memory and callstack are available

    ins_write_vars = set(ins.getWriteVarOperands())
    ins_read_vars = set(ins.getReadVarOperands())

    if ins_write_vars.intersection(mvars):
      ssa_map = ssa.getMap(ins_read_vars.difference(mvars), ins_write_vars,
                           ins_read_vars.intersection(mvars))

      cons = conds.get(pins.instruction, Condition)
      condition = cons(ins, ssa_map)

      mvars = mvars.difference(ins_write_vars)
      mvars = ins_read_vars.union(mvars)
      mvars = set(o for o in mvars if o.name != "ebp")

      smt_conds.add(condition.getEq())

  # Only constrain operands the SSA state knows about. Build a new dict
  # rather than deleting entries from the caller's mapping.
  known_values = dict((iop, value) for iop, value in initial_values.items()
                      if iop in ssa)

  ssa_map = ssa.getMap(set(), set(), set(known_values.keys()))
  eq = Eq(None, None)

  for iop in known_values:
    smt_conds.add(eq.getEq(ssa_map[iop.name], known_values[iop]))

  op.name = op.name + "_0"
  smt_conds.solve()

  return smt_conds.getValue(op)
Ejemplo n.º 12
0
def getExploitConditions(trace, value, address, filename):
  """Check whether the trace's final store can write ``value`` to ``address``.

  trace: dict providing "callstack", "raw_code" and "mem_access".
  value: value the final ``stm`` should store.
  address: address the final ``stm`` should store to.
  filename: accepted for interface compatibility; not used here.

  Prints an error and returns when the last instruction is not an ``stm``.
  Otherwise it hands the preceding code to getValueFromCode. Returns None.
  """
  callstack = trace["callstack"]
  code = trace["raw_code"]
  accesses = trace["mem_access"]

  last = parse_reil(code[-1])
  if last.instruction != "stm":
    print("#ERROR: Selected instruction is not a store memory")
    return

  store = Instruction(last, None)
  target_op = store.getMemReg()
  source_op = store.getReadRegOperands()[0]

  # Solve over every instruction that precedes the store.
  getValueFromCode(code[:-1], callstack, accesses, target_op, address,
                   source_op, value)
Ejemplo n.º 13
0
def getPathConditions(trace, filename):
  """Collect the path constraints of a recorded trace and write them out.

  The raw instructions in trace["raw_code"] are walked in order. Every
  ``jcc``, and every instruction that writes a currently tracked operand,
  becomes an SSA-renamed SMT condition. The operands that instruction reads
  become tracked (``ebp`` is never tracked). The constraints are written to
  ``filename + ".smt2"`` and ``filename + ".sol"``.

  trace: dict providing "callstack", "raw_code", "memaccess" (an object with
    ``getAccess(address)``) and "current_call".
  filename: output path prefix.
  """
  callstack = list(trace["callstack"])  # NOTE(review): never used below
  code = list(trace["raw_code"])
  mem_access = trace["memaccess"]  # replaced per instruction in the loop

  SSA.SSAinit()

  tracked = set()
  smt_conds = SMT()

  for raw in code:
    pins = parse_reil(raw)

    current_call = trace["current_call"]
    mem_access = trace["memaccess"].getAccess(pins.address)
    ins = Instruction(pins, current_call, mem_access)

    writes = set(ins.getWriteVarOperands())
    reads = set(ins.getReadVarOperands())

    # Only branches and writers of tracked operands contribute constraints.
    if pins.instruction != "jcc" and not writes.intersection(tracked):
      continue

    ssa_map = SSA.SSAMapping(reads.difference(tracked), writes,
                             reads.intersection(tracked))
    condition = conds.get(pins.instruction, Condition)(ins, ssa_map)

    tracked = reads.union(tracked.difference(writes))
    tracked = set(o for o in tracked if o.name != "ebp")

    smt_conds.add(condition.getEq())

  smt_conds.write_smtlib_file(filename + ".smt2")
  smt_conds.write_sol_file(filename + ".sol")
Ejemplo n.º 14
0
 def detectMemAccess(self, reil_code, callstack, inputs, counter):
   """Record the memory access made by the last instruction of ``reil_code``.

   The last instruction must be an ``stm`` or ``ldm``. Its address operand
   is evaluated backwards over ``reil_code`` with getTypedValueFromCode,
   passing ``self`` as the memory-access source. A memory-typed result is
   stored via ``__getMemAccess__``. An immediate result is treated as a
   global address and stored via ``__getGlobalMemAccess__``. Either way the
   entry goes into ``self.access[counter]``.

   Raises:
     AssertionError: if the address evaluates to something that is neither
       memory nor an immediate (and, when assertions are enabled, if the
       instruction is not stm/ldm).
   """
   pins = parse_reil(reil_code[-1])
   ins = Instruction(pins, None)

   assert(ins.instruction in ["stm", "ldm"])
   addr_op = ins.getMemReg()
   val = getTypedValueFromCode(reil_code, callstack, inputs, self, addr_op)

   if val.isMem():
     self.access[counter] = self.__getMemAccess__(ins, val)
   # isImm must be called: a bare bound method is always truthy and would
   # route every non-memory value into this branch.
   elif val.isImm():
     self.access[counter] = self.__getGlobalMemAccess__(ins, int(val.name))
   else:
     raise AssertionError("unsupported address value %s at %s"
                          % (val, ins.address))
Ejemplo n.º 15
0
 def prevInstruction(self, ins):
   """Step the callstack cursor back one frame if ``ins`` is a call or ret.

   ins: raw REIL instruction string, as passed by the backward walkers.
   """
   if parse_reil(ins).instruction in ("call", "ret"):
     self.index -= 1
Ejemplo n.º 16
0
def getTypedValueFromCode(inss, callstack, initial_values, memory, op, debug = False):
  """Backward-solve the value of ``op`` at the end of ``inss`` and type it.

  ``inss`` is walked from its last instruction to its first. Along the way:
  - every operand ``op`` depends on is tracked;
  - each instruction that writes a tracked operand becomes an SMT condition;
  - ``detectType`` and ``addAditionalConditions`` refine the result type and
    the tracked set.
  ``initial_values`` then constrains the operands reached at the start, the
  problem is solved, and ``mkVal(val_type, value)`` is returned.

  Args:
    inss: REIL code to walk. NOTE(review): ``inss.reverse()`` is called on
      this object in place (even for immediates), so the caller sees it.
    callstack: callstack object. Its ``index`` is moved back by
      ``prevInstruction`` during the walk and restored before the final
      return. It is not restored if an exception escapes.
    initial_values: known operand values used as initial conditions;
      tracked operands missing from it are reported as free.
    memory: object whose ``getAccess(counter)`` supplies the memory access
      for the instruction being processed. ``counter`` starts at
      ``len(inss) - 1`` and counts down, presumably the index in forward
      trace order.
    op: operand to evaluate. Immediates are returned unchanged. Memory
      operands are tracked as one BYTE operand per byte
      (``"<mem_source>@<offset>"``). NOTE(review): ``op.name`` (register) or
      ``op.mem_source`` (memory) is suffixed with ``"_0"`` in place.
    debug: forwarded to ``smt_conds.solve``. When true, the result is also
      printed.

  Returns:
    ``op`` itself for immediates. Otherwise ``mkVal(val_type, value)``,
    where ``val_type`` is ``"imm"`` when no type was detected.
  """
  
  # Initialization
  
  # we reverse the code order
  inss.reverse()
  
  # we reset the used memory variables
  Memvars.reset()
  
  # we save the current callstack position so it can be restored at the end
  last_index = callstack.index  # TODO: create a better interface
  
  # instruction counter: index of the instruction currently processed
  counter = len(inss)-1
  
  # ssa and smt objects
  ssa = SSA()
  smt_conds  = SMT()
  
  val_type = None
  mvars = set()
 
  if (op.isImm()):
    return op
  elif (op.isMem()):
    # a memory operand is tracked byte by byte
    for i in range(op.size):
      name = op.mem_source+"@"+str(op.mem_offset+i)
      mvars.add(Operand(name, "BYTE", op.mem_source, op.mem_offset+i))
      #print name
  else:
    # we will start tracking op
    mvars.add(op)
    
  # NOTE(review): fvars is never used in this function
  fvars = set()
  
  ssa.getMap(mvars, set(), set())

  for ins_str in inss:
    #print inss.current, "->", ins_str.strip("\n")
    pins = parse_reil(ins_str)
    ins = Instruction(pins, memory.getAccess(counter), mem_regs = False)
  
    ins_write_vars = set(ins.getWriteVarOperands())
    ins_read_vars = set(ins.getReadVarOperands())
 
    # only instructions that write a tracked operand produce a condition
    if len(ins_write_vars.intersection(mvars)) > 0: 
      
      ssa_map = ssa.getMap(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

      cons = conds.get(pins.instruction, Condition)
      condition = cons(ins, ssa_map)
     
      # written operands are resolved; the ones read are now tracked
      mvars = mvars.difference(ins_write_vars) 
      mvars = ins_read_vars.union(mvars)
   
      smt_conds.add(condition.getEq())
    
    # simple typing
    new_val_type = detectType(mvars, ins, counter, callstack)
    
    # additional conditions (may rewrite the tracked set)
    mvars = addAditionalConditions(mvars, ins, ssa, callstack, smt_conds)
    
    # NOTE(review): relies on Python 2 ordering, where None sorts below any
    # value; this raises TypeError on Python 3.
    val_type = max(val_type, new_val_type)


    # no more things to do
    # we update the counter 
    counter = counter - 1    
    # we update the current call for next instruction
    callstack.prevInstruction(ins_str) 
  
  if val_type == None:
    val_type = "imm"
  
  # operands still tracked at the start without a known value are free
  for v in mvars:
    if not (v in initial_values):
      print "#Warning__", str(v), "is free!" 
  
  setInitialConditions(ssa, initial_values, smt_conds)
  smt_conds.solve(debug)
  
  # rename op to its SSA name before asking the solver for its value
  if op.isReg():
    op.name = op.name+"_0"
    
  elif op.isMem():
    op.mem_source = op.mem_source+"_0"
  
  callstack.index = last_index  # TODO: create a better interface
  if (debug):
    print val_type, op, smt_conds.getValue(op)
  return mkVal(val_type, smt_conds.getValue(op))
Ejemplo n.º 17
0
def getPathConditions(trace):
  """Solve for inputs that drive the trace into its requested final state.

  The trace code is walked backwards, starting from the operands in
  trace["final_conditions"]. Three kinds of instruction contribute SMT
  conditions:
  - every ``jcc``;
  - every instruction that writes a tracked operand;
  - every resolved call whose function model writes one (using the recorded
    function parameters).
  trace["initial_conditions"] then constrains the operands reached at the
  start.

  Args:
    trace: dict as built by mkTrace. NOTE(review): ``trace["code"]`` is
      reversed in place and ``trace["callstack"]`` is moved back by
      ``prevInstruction``; neither is restored here.

  Returns:
    ``Solution(smt_conds.m, fvars)`` when the constraints are satisfiable
    (after writing "exp.smt2" and "exp.sol" to the current directory),
    otherwise None.
  """
  
  inss = trace["code"]
  callstack = trace["callstack"]
  
  initial_values = trace["initial_conditions"]
  final_values = trace["final_conditions"]
  memory = trace["mem_access"]
  parameters = trace["func_parameters"]
  
  # we reverse the code order
  inss.reverse()
  
  # we reset the used memory variables
  Memvars.reset()
  
  # instruction counter: index of the instruction currently processed
  counter = len(inss)-1
  
  # ssa and smt objects
  ssa = SSA()
  smt_conds  = SMT()
  
  # auxiliary eq condition (NOTE(review): never used below)
  
  eq = Eq(None, None)
  mvars = set()
  
  # final conditions: start tracking every constrained operand
  
  for (op, _) in final_values.items():
    mvars.add(op)
  
  ssa.getMap(mvars, set(), set())
  setInitialConditions(ssa, final_values, smt_conds)
  
   
  # NOTE(review): overwritten below before it is ever read
  fvars = set()

  for ins_str in inss:
    pins = parse_reil(ins_str)
    ins = Instruction(pins, memory.getAccess(counter), mem_regs = False)  
  
    ins_write_vars = set(ins.getWriteVarOperands())
    ins_read_vars = set(ins.getReadVarOperands())
 
    # branches always contribute; other instructions only if they write a
    # tracked operand
    if pins.instruction == "jcc" or len(ins_write_vars.intersection(mvars)) > 0:
      
      ssa_map = ssa.getMap(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

      cons = conds.get(pins.instruction, Condition)
      condition = cons(ins, ssa_map)
     
      mvars = mvars.difference(ins_write_vars) 
      mvars = ins_read_vars.union(mvars)
   
      smt_conds.add(condition.getEq())
      
    # resolved calls are summarised by their function model
    elif (ins.instruction == "call" and ins.called_function <> None):
      
      func_cons = funcs.get(ins.called_function, Function)
      func = func_cons(None, parameters.getParameters(counter))
      
      func_write_vars = set(func.getWriteVarOperands())
      func_read_vars = set(func.getReadVarOperands())
      
      if len(func_write_vars.intersection(mvars)) > 0:
        # NOTE(review): ssa_map is computed but the condition below is built
        # with None; presumably getMap is called for its renaming side effect
        ssa_map = ssa.getMap(func_read_vars.difference(mvars), func_write_vars, func_read_vars.intersection(mvars))
        
        
        cons = conds.get(ins.called_function, Condition)
        condition = cons(func, None)
        
        c = condition.getEq(func_write_vars.intersection(mvars))
        
        mvars = mvars.difference(func_write_vars) 
        mvars = func_read_vars.union(mvars)
        
        smt_conds.add(c)
    
    # additional conditions (may rewrite the tracked set)
    
    mvars = addAditionalConditions(mvars, ins, ssa, callstack, smt_conds)
    

    # no more things to do
    # we update the counter 
    counter = counter - 1    
    # we update the current call for next instruction
    callstack.prevInstruction(ins_str) 
  
  # Tracked operands without an initial value are free variables.
  # NOTE(review): on Python 3 filter() is a one-shot iterator, so the warning
  # loop would exhaust it before it reaches Solution(); this relies on the
  # Python 2 list result.
  fvars = filter(lambda v: not (v in initial_values.keys()), mvars)
  for v in fvars:
    if not (v in initial_values) and not (":" in v.name):
      print "#Warning", str(v), "is free!" 
  
  setInitialConditions(ssa, initial_values, smt_conds)
  
  if (smt_conds.is_sat()):
    smt_conds.solve()
  
    smt_conds.write_smtlib_file("exp.smt2")  
    smt_conds.write_sol_file("exp.sol")
  
    return Solution(smt_conds.m, fvars)
  else: # unsat :(
    return None
Ejemplo n.º 18
0
def mkTrace(trace_filename, first, last, raw_inputs):
    """Load a REIL trace and run the analyses needed to reason about it.

    Loads the trace with ``REIL_Trace(trace_filename, first, last)``
    (presumably the first..last instruction range) and rebuilds its
    callstack layout. It then scans the trace once, recording:
    - every ``stm``/``ldm`` memory access;
    - the parameters of every call with a resolved ``called_function``;
    - malloc/free events in an allocation log.
    Each analysis result is printed.

    Returns:
        dict with keys "code", "initial_conditions", "final_conditions",
        "callstack", "mem_access" and "func_parameters".
    """
    print("Loading trace..")
    code = REIL_Trace(trace_filename, first, last)

    inputs = parse_inputs(raw_inputs)
    if raw_inputs != []:
        print("Using these inputs..")
        for op in inputs:
            print(" ".join(str(part) for part in (op, "=", inputs[op])))

    print("Detecting callstack layout...")
    # TODO: CallstackREIL should also receive the inputs.
    callstack = CallstackREIL(code)
    code.reset()
    print(callstack)

    allocations = Allocation()
    mem_access = MemAccessREIL()
    func_params = FuncParametersREIL()

    # Every analysis window starts at the beginning of the trace.
    window_start = 0
    callstack.reset()

    print("Detecting memory accesses and function parameters..")
    for pos, raw in enumerate(code):
        ins = Instruction(parse_reil(raw), None)
        callstack.nextInstruction(raw)

        if ins.instruction in ["stm", "ldm"]:
            mem_access.detectMemAccess(code[window_start:pos + 1], callstack,
                                       inputs, pos)
            allocations.check(mem_access.getAccess(pos), pos)
            continue

        if ins.instruction != "call" or ins.called_function is None:
            continue

        func_params.detectFuncParameters(code[window_start:pos + 1],
                                         mem_access, callstack, inputs, pos)
        if ins.called_function == "malloc":
            try:
                size = int(func_params.getParameters(pos)[0][1].name)
            except ValueError:
                size = None
            allocations.alloc(ins.address, pos, size)
        elif ins.called_function == "free":
            ptr = func_params.getParameters(pos)[0][1].mem_source
            allocations.free(ptr, pos)

    print(mem_access)
    print(func_params)
    allocations.report()

    callstack.reset()
    code.reset()

    return {
        "code": code,
        "initial_conditions": inputs,
        "final_conditions": dict(),
        "callstack": callstack,
        "mem_access": mem_access,
        "func_parameters": func_params,
    }
Ejemplo n.º 19
0
 def prevInstruction(self, ins):
     """Step the callstack cursor back one frame if ``ins`` is a call or ret.

     ins: raw REIL instruction string, as passed by the backward walkers.
     """
     if parse_reil(ins).instruction in ("call", "ret"):
         self.index -= 1
Ejemplo n.º 20
0
def getValueFromCode(reil_code, callstack, memory, addr_op, addr, val_op, val):
  """Check whether a final store of ``val`` to ``addr`` is reachable.

  The end of ``reil_code`` is constrained so that ``addr_op == addr`` and
  ``val_op == val`` (with ``val`` sized like ``val_op``). A reversed copy of
  the code is then walked. Every ``jcc``, and every instruction that writes
  a tracked operand, becomes an SMT condition.

  At each ``call``:
  - tracked ``esp`` is pinned to 4 (when ``callstack.index == 1``) or 8, and
    tracked ``ebp`` to 0;
  - stack memory operands are translated with
    ``callstack.convertStackMemOp``.

  The constraints are solved and written to "exp.smt2"/"exp.sol". If they
  are satisfiable, the values of the free variables are printed; otherwise
  "Not exploitable" is printed.

  Args:
    reil_code: non-empty sequence of raw REIL instruction strings (copied,
      not modified).
    callstack: callstack object; its ``index`` is moved back by
      ``prevInstruction`` and not restored.
    memory: object whose ``getAccess(counter)`` supplies per-instruction
      memory accesses.
    addr_op, val_op: address and value operands of the final store.
    addr, val: target address and value (converted with ``str``).

  Returns:
    None; results are printed and written to files.
  """

  assert(reil_code <> [])
  
  free_variables = []
  
  # code is copied and reversed
  inss = list(reil_code) 
  inss.reverse()
  
  # NOTE(review): counter starts at len(reil_code), one past the last index,
  # whereas getTypedValueFromCode starts at len(inss)-1 for the same
  # memory.getAccess(counter) lookup -- possible off-by-one; confirm how
  # memory accesses are indexed.
  counter = len(reil_code)
  
  # NOTE(review): tracked_stack_frame is assigned but never read
  tracked_stack_frame = callstack.index
  
  # especial operands in a call

  ssa = SSA()
  smt_conds  = SMT()
 
  # we will track the store's address and value operands
  mvars = set([addr_op, val_op])    
  ssa_map = ssa.getMap(mvars, set(), set())
  eq = Eq(None, None)
  
  addr = Operand(str(addr), "DWORD")
  val = Operand(str(val), "BYTE")
  val.size = val_op.size
  
  # final-state constraints: the store hits addr with val
  smt_conds.add(eq.getEq(ssa_map[addr_op.name],addr))
  smt_conds.add(eq.getEq(ssa_map[val_op.name],val))

  for ins_str in inss:
    
    pins = parse_reil(ins_str)
    
    ins = Instruction(pins, memory.getAccess(counter), mem_regs = False)
  
    ins_write_vars = set(ins.getWriteVarOperands())
    ins_read_vars = set(ins.getReadVarOperands())

    if pins.instruction == "jcc" or len(ins_write_vars.intersection(mvars)) > 0: 
    #if len(ins_write_vars.intersection(mvars)) > 0: 
      
      ssa_map = ssa.getMap(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

      cons = conds.get(pins.instruction, Condition)
      condition = cons(ins, ssa_map)
     
      mvars = mvars.difference(ins_write_vars) 
      mvars = ins_read_vars.union(mvars)
   
      smt_conds.add(condition.getEq())
      
    counter = counter - 1
    
    if len(mvars) > 0:
      tracked_stack_frame = callstack.index
    
    if pins.instruction == "call":
      
      # esp/ebp values assumed at function entry
      if callstack.index == 1:
        esp_val = 4
      else:
        esp_val = 8
 
      ebp_val = 0
  
      esp_op = Operand("esp","DWORD")
      ebp_op = Operand("ebp","DWORD")
  
      initial_values_at_call = dict()
      initial_values_at_call[esp_op] = Operand(str(esp_val), "DWORD")
      initial_values_at_call[ebp_op] = Operand(str(ebp_val), "DWORD") 

      
      # keep only the registers that are still tracked
      # (deleting while looping over keys() is only safe on Python 2, where
      # keys() returns a list)
      for iop in initial_values_at_call.keys():
        if not (iop in mvars):
          del initial_values_at_call[iop]
      
      ssa_map = ssa.getMap(set(), set(), set(initial_values_at_call.keys()))
      eq = Eq(None, None)
    
      for iop in initial_values_at_call:
        smt_conds.add(eq.getEq(ssa_map[iop.name],initial_values_at_call[iop]))
      
      mvars = set(filter(lambda o: not (o in initial_values_at_call.keys()), mvars))
      
      # after the first instruction in forward order (counter == 0), anything
      # still tracked is free; free_variables is only filled when that
      # instruction is a call
      if (counter == 0 and len(mvars)>0):
        free_variables = mvars
        break
      
      # translate stack memory operands across the call boundary
      new_mvars = set()
      for v in mvars:
        if v.isMem(): # this should work for stack memory 
          eop = callstack.convertStackMemOp(v)
          smt_conds.add(eq.getEq(v,eop))
          new_mvars.add(eop)
      
      mvars = set(filter(lambda o: not (o.isMem()), mvars))
      mvars = mvars.union(new_mvars)
    
    # we update the current call for next instruction
    callstack.prevInstruction(ins_str) 
    
  # NOTE(review): solve() runs before is_sat() is checked; getPathConditions
  # checks is_sat() first.
  smt_conds.solve()
  smt_conds.write_smtlib_file("exp.smt2")
  smt_conds.write_sol_file("exp.sol")
  
  if (smt_conds.is_sat()):
    print "Solution:",
    for v in free_variables:
      if v.isReg():
        # NOTE(review): ssa_map is indexed by operand name elsewhere in this
        # function (ssa_map[iop.name]), but by the operand itself here.
        if (v in ssa_map):
          print v,smt_conds.getValue(ssa_map[v])
      elif v.isMem():
        # NOTE(review): `stack` is not defined in this function; this raises
        # NameError unless a module-level `stack` exists.
        sname, offset = stack.read(v)
        v.mem_source = sname
        print v, smt_conds.getValue(v)
  else:
    print "Not exploitable"
Ejemplo n.º 21
0
def mkTrace(trace_filename, first, last, raw_inputs):
    """Load a REIL trace and run the analyses needed to reason about it.

    Loads the trace with ``REIL_Trace(trace_filename, first, last)``
    (presumably the first..last instruction range) and rebuilds its
    callstack layout. It then scans the trace once, recording:
    - every ``stm``/``ldm`` memory access;
    - the parameters of every call with a resolved ``called_function``;
    - malloc/free events in an allocation log.
    Each analysis result is printed.

    Returns:
        dict with keys "code", "initial_conditions", "final_conditions",
        "callstack", "mem_access" and "func_parameters".
    """
    print("Loading trace..")
    code = REIL_Trace(trace_filename, first, last)

    inputs = parse_inputs(raw_inputs)
    if raw_inputs != []:
        print("Using these inputs..")
        for op in inputs:
            print(" ".join(str(part) for part in (op, "=", inputs[op])))

    print("Detecting callstack layout...")
    # TODO: CallstackREIL should also receive the inputs.
    callstack = CallstackREIL(code)
    code.reset()
    print(callstack)

    allocations = Allocation()
    mem_access = MemAccessREIL()
    func_params = FuncParametersREIL()

    # Every analysis window starts at the beginning of the trace.
    window_start = 0
    callstack.reset()

    print("Detecting memory accesses and function parameters..")
    for pos, raw in enumerate(code):
        ins = Instruction(parse_reil(raw), None)
        callstack.nextInstruction(raw)

        if ins.instruction in ["stm", "ldm"]:
            mem_access.detectMemAccess(code[window_start:pos + 1], callstack,
                                       inputs, pos)
            allocations.check(mem_access.getAccess(pos), pos)
            continue

        if ins.instruction != "call" or ins.called_function is None:
            continue

        func_params.detectFuncParameters(code[window_start:pos + 1],
                                         mem_access, callstack, inputs, pos)
        if ins.called_function == "malloc":
            try:
                size = int(func_params.getParameters(pos)[0][1].name)
            except ValueError:
                size = None
            allocations.alloc(ins.address, pos, size)
        elif ins.called_function == "free":
            ptr = func_params.getParameters(pos)[0][1].mem_source
            allocations.free(ptr, pos)

    print(mem_access)
    print(func_params)
    allocations.report()

    callstack.reset()
    code.reset()

    return {
        "code": code,
        "initial_conditions": inputs,
        "final_conditions": dict(),
        "callstack": callstack,
        "mem_access": mem_access,
        "func_parameters": func_params,
    }