Exemplo n.º 1
0
def getPathConditions(trace, filename):
  """Collect path conditions from a REIL trace and write them to files.

  Every line of ``trace["raw_code"]`` is parsed with ``parse_reil`` and
  wrapped in an ``Instruction`` together with ``trace["current_call"]`` and
  the memory access recorded for its address in ``trace["memaccess"]``.
  An instruction contributes a condition when it is a ``jcc`` or when it
  writes one of the currently tracked operands (``mvars``).  After that,
  the operands it writes stop being tracked and the operands it reads start
  being tracked.  Operands named "ebp" are never tracked.

  Args:
    trace: dict-like trace with the keys "raw_code", "memaccess" and
      "current_call".
    filename: output path prefix; the conditions are written to
      ``filename + ".smt2"`` and ``filename + ".sol"``.

  NOTE(review): instructions are visited in the order given by
  trace["raw_code"].  The tracking scheme (writes untracked, reads tracked)
  looks like a backward walk, and other variants of this function reverse
  the code first.  Confirm that the caller already supplies the code reversed.
  """
  inss = list(trace["raw_code"])
  memaccess = trace["memaccess"]

  SSA.SSAinit()

  mvars = set()  # operands the collected conditions still depend on
  smt_conds = SMT()

  for ins_str in inss:
    # Instruction parsing
    pins = parse_reil(ins_str)

    # Instruction processing
    current_call = trace["current_call"]
    ins = Instruction(pins, current_call, memaccess.getAccess(pins.address))

    ins_write_vars = set(ins.getWriteVarOperands())
    ins_read_vars = set(ins.getReadVarOperands())

    if pins.instruction == "jcc" or ins_write_vars.intersection(mvars):
      ssa_map = SSA.SSAMapping(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

      cons = conds.get(pins.instruction, Condition)
      condition = cons(ins, ssa_map)

      # The written operands are now described by this condition; the
      # operands they were computed from must be described next.
      mvars = ins_read_vars.union(mvars.difference(ins_write_vars))
      mvars = set(o for o in mvars if o.name != "ebp")

      smt_conds.add(condition.getEq())

  smt_conds.write_smtlib_file(filename + ".smt2")
  smt_conds.write_sol_file(filename + ".sol")
Exemplo n.º 2
0
def getPathConditions(trace, filename):
  """Collect path conditions from a REIL trace and write them to files.

  Every line of ``trace["raw_code"]`` is parsed with ``parse_reil`` and
  wrapped in an ``Instruction`` together with ``trace["current_call"]`` and
  the memory access recorded for its address in ``trace["memaccess"]``.
  An instruction contributes a condition when it is a ``jcc`` or when it
  writes one of the currently tracked operands (``mvars``).  After that,
  the operands it writes stop being tracked and the operands it reads start
  being tracked.  Operands named "ebp" are never tracked.

  Args:
    trace: dict-like trace with the keys "raw_code", "memaccess" and
      "current_call".
    filename: output path prefix; the conditions are written to
      ``filename + ".smt2"`` and ``filename + ".sol"``.

  NOTE(review): instructions are visited in the order given by
  trace["raw_code"].  The tracking scheme (writes untracked, reads tracked)
  looks like a backward walk, and other variants of this function reverse
  the code first.  Confirm that the caller already supplies the code reversed.
  """
  inss = list(trace["raw_code"])
  memaccess = trace["memaccess"]

  SSA.SSAinit()

  mvars = set()  # operands the collected conditions still depend on
  smt_conds = SMT()

  for ins_str in inss:
    # Instruction parsing
    pins = parse_reil(ins_str)

    # Instruction processing
    current_call = trace["current_call"]
    ins = Instruction(pins, current_call, memaccess.getAccess(pins.address))

    ins_write_vars = set(ins.getWriteVarOperands())
    ins_read_vars = set(ins.getReadVarOperands())

    if pins.instruction == "jcc" or ins_write_vars.intersection(mvars):
      ssa_map = SSA.SSAMapping(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

      cons = conds.get(pins.instruction, Condition)
      condition = cons(ins, ssa_map)

      # The written operands are now described by this condition; the
      # operands they were computed from must be described next.
      mvars = ins_read_vars.union(mvars.difference(ins_write_vars))
      mvars = set(o for o in mvars if o.name != "ebp")

      smt_conds.add(condition.getEq())

  smt_conds.write_smtlib_file(filename + ".smt2")
  smt_conds.write_sol_file(filename + ".sol")
Exemplo n.º 3
0
def getPathConditions(trace, debug = False):
  
  # Initialization
  inss = trace["code"]
  callstack = trace["callstack"]
  
  memory = trace["mem_access"]
  parameters = trace["func_parameters"]
 
  # we reverse the code order
  inss.reverse()
  #print inss[0]
  # we reset the used memory variables
  Memvars.reset()
  
  # we save the current callstack
  last_index = callstack.index  # TODO: create a better interface
  
  # ssa and smt objects
  ssa = SSA()
  smt_conds  = SMT()
  
  mvars = set()
  mlocs = set()

  for op in trace["final_conditions"]:
    mvars.add(op)
    mlocs = mlocs.union(op.getLocations())
  
  # we start without free variables
  fvars = set()
  
  ssa.getMap(mvars, set(), set())
  setInitialConditions(ssa, trace["final_conditions"],smt_conds)
  
  #for c in smt_conds:
  #  print c
  #assert(0)   

  for ins in inss:
    
    
    counter = ins.getCounter()
    func_cons = funcs.get(ins.called_function, Function)

    if memory.getAccess(counter) <> None:
      ins.setMemoryAccess(memory.getAccess(counter))

    ins.clearMemRegs() 
    func = func_cons(None, parameters.getParameters(counter))

    if debug:
      print "(%.4d)" % counter, ins
      for v in mvars:
        print v, v.getSizeInBytes(), "--",
      print ""
     
      for l in mlocs:
        print l, "--",
      print ""
  
    ins_write_vars = set(ins.getWriteVarOperands())
    ins_read_vars = set(ins.getReadVarOperands())
   
    func_write_vars = set(func.getWriteVarOperands())
    func_read_vars = set(func.getReadVarOperands())

    ins_write_locs = concatSet(map(lambda op: set(op.getLocations()), ins.getWriteVarOperands()))
    ins_read_locs  = concatSet(map(lambda op: set(op.getLocations()), ins.getReadVarOperands()))
    
    func_write_locs = concatSet(map(lambda op: set(op.getLocations()), func.getWriteVarOperands()))
    func_read_locs  = concatSet(map(lambda op: set(op.getLocations()), func.getReadVarOperands()))
    
    #if (func_write_vars <> set()):
    #  x =  func_write_vars.pop()
    #  print x, x.getLocations()
    #  assert(0)
    #print func, parameters.getParameters(counter), func_write_vars, func_write_locs 

    if (not ins.isCall()) and (ins.isJmp() or ins.isCJmp() or len(ins_write_locs.intersection(mlocs)) > 0): 
      
      ssa_map = ssa.getMap(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

      cons = conds.get(ins.instruction, Condition)
      condition = cons(ins, ssa_map)
      
      mlocs = mlocs.difference(ins_write_locs) 
      mlocs = ins_read_locs.union(mlocs) 
       
      mvars = mvars.difference(ins_write_vars) 
      mvars = ins_read_vars.union(mvars)
   
      smt_conds.add(condition.getEq())
      
    elif (len(func_write_locs.intersection(mlocs)) > 0):
      # TODO: clean-up here!
      #ssa_map = ssa.getMap(func_read_vars.difference(mvars), func_write_vars, func_read_vars.intersection(mvars))
        
      cons = conds.get(ins.called_function, Condition)
      condition = cons(func, None)
        
      c = condition.getEq(func_write_locs.intersection(mlocs))
      
      mlocs = mlocs.difference(func_write_locs) 
      mlocs = func_read_locs.union(mlocs) 
  
      mvars = mvars.difference(func_write_vars) 
      mvars = func_read_vars.union(mvars)

      smt_conds.add(c)
      #print c
      #assert(0)

    
    # additional conditions
    #mvars = addAditionalConditions(mvars, mlocs, ins, ssa, callstack, smt_conds)

    # we update the current call for next instruction
    callstack.prevInstruction(ins) 
  
  fvars = set()
  ssa_map = ssa.getMap(set(), set(), mvars)

  for var in mvars:
    #print v, "--",
    #if not (v in initial_values):
    print "#Warning__", str(var), "is free!" 
    
    if (var |iss| InputOp):
      fvars.add(var)
    elif var |iss| MemOp:
      f_op = var.copy()
      f_op.name = Memvars.read(var)
      fvars.add(f_op) 
    else:
      f_op = var.copy()
      f_op.name = f_op.name+"_0"
      fvars.add(f_op)
    #else:
      #fvars.add(ssa_map[str(var)])
      # perform SSA
      #assert(0)
  
  #setInitialConditions(ssa, initial_values, smt_conds)
  #smt_conds.solve(debug)
  
  callstack.index = last_index  # TODO: create a better interface
  smt_conds.write_smtlib_file("exp.smt2")  
  smt_conds.write_sol_file("exp.sol")
  smt_conds.solve(debug)  

  if (smt_conds.is_sat()):
    #smt_conds.solve(debug)
    return (fvars, Solution(smt_conds.m))
  else: # unsat :(
    return (set(), None)
Exemplo n.º 4
0
Arquivo: Common.py Projeto: YHVHvx/SEA
def getPathConditions(trace):

    inss = trace["code"]
    callstack = trace["callstack"]

    initial_values = trace["initial_conditions"]
    final_values = trace["final_conditions"]
    memory = trace["mem_access"]
    parameters = trace["func_parameters"]

    # we reverse the code order
    inss.reverse()

    # we reset the used memory variables
    Memvars.reset()

    # we set the instruction counter
    counter = len(inss) - 1

    # ssa and smt objects
    ssa = SSA()
    smt_conds = SMT()

    # auxiliary eq condition

    eq = Eq(None, None)
    mvars = set()

    # final conditions:

    for (op, _) in final_values.items():
        mvars.add(op)

    ssa.getMap(mvars, set(), set())
    setInitialConditions(ssa, final_values, smt_conds)

    # we start without free variables
    fvars = set()

    for ins in inss:

        if memory.getAccess(counter) <> None:
            ins.fixMemoryAccess(memory.getAccess(counter))

        ins_write_vars = set(ins.getWriteVarOperands())
        ins_read_vars = set(ins.getReadVarOperands())

        if ins.instruction == "jcc" or len(ins_write_vars.intersection(mvars)) > 0:

            ssa_map = ssa.getMap(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

            cons = conds.get(ins.instruction, Condition)
            condition = cons(ins, ssa_map)

            mvars = mvars.difference(ins_write_vars)
            mvars = ins_read_vars.union(mvars)

            smt_conds.add(condition.getEq())

        elif ins.isCall() and ins.called_function <> None:

            func_cons = funcs.get(ins.called_function, Function)
            func = func_cons(None, parameters.getParameters(counter))

            func_write_vars = set(func.getWriteVarOperands())
            func_read_vars = set(func.getReadVarOperands())

            if len(func_write_vars.intersection(mvars)) > 0:
                ssa_map = ssa.getMap(
                    func_read_vars.difference(mvars), func_write_vars, func_read_vars.intersection(mvars)
                )

                cons = conds.get(ins.called_function, Condition)
                condition = cons(func, None)

                c = condition.getEq(func_write_vars.intersection(mvars))

                mvars = mvars.difference(func_write_vars)
                mvars = func_read_vars.union(mvars)

                smt_conds.add(c)

        # additional conditions

        mvars = addAditionalConditions(mvars, ins, ssa, callstack, smt_conds)

        # no more things to do
        # we update the counter
        counter = counter - 1
        # we update the current call for next instruction
        callstack.prevInstruction(ins)

    # for v in mvars:
    #  print v

    fvars = filter(lambda v: not (v in initial_values.keys()), mvars)
    for v in fvars:
        #  print v,n
        if not (v in initial_values) and not (":" in v.name):
            print "#Warning", str(v), "is free!"

    setInitialConditions(ssa, initial_values, smt_conds)

    if smt_conds.is_sat():
        smt_conds.solve()

        smt_conds.write_smtlib_file("exp.smt2")
        smt_conds.write_sol_file("exp.sol")

        return Solution(smt_conds.m, fvars)
    else:  # unsat :(
        return None
Exemplo n.º 5
0
def getPathConditions(trace):
  
  inss = trace["code"]
  callstack = trace["callstack"]
  
  initial_values = trace["initial_conditions"]
  final_values = trace["final_conditions"]
  memory = trace["mem_access"]
  parameters = trace["func_parameters"]
  
  # we reverse the code order
  inss.reverse()
  
  # we reset the used memory variables
  Memvars.reset()
  
  # we set the instruction counter
  counter = len(inss)-1
  
  # ssa and smt objects
  ssa = SSA()
  smt_conds  = SMT()
  
  # auxiliary eq condition
  
  eq = Eq(None, None)
  mvars = set()
  
  # final conditions:
  
  for (op, _) in final_values.items():
    mvars.add(op)
  
  ssa.getMap(mvars, set(), set())
  setInitialConditions(ssa, final_values, smt_conds)
  
   
  # we start without free variables
  fvars = set()

  for ins_str in inss:
    #print ins_str.strip("\n")
    #for v in mvars:
    #  print v,
    #   
    pins = parse_reil(ins_str)
    ins = Instruction(pins, memory.getAccess(counter), mem_regs = False)  
  
    ins_write_vars = set(ins.getWriteVarOperands())
    ins_read_vars = set(ins.getReadVarOperands())
 
    if pins.instruction == "jcc" or len(ins_write_vars.intersection(mvars)) > 0:
      
      ssa_map = ssa.getMap(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

      cons = conds.get(pins.instruction, Condition)
      condition = cons(ins, ssa_map)
     
      mvars = mvars.difference(ins_write_vars) 
      mvars = ins_read_vars.union(mvars)
   
      smt_conds.add(condition.getEq())
      
    elif (ins.instruction == "call" and ins.called_function <> None):
      
      func_cons = funcs.get(ins.called_function, Function)
      func = func_cons(None, parameters.getParameters(counter))
      
      func_write_vars = set(func.getWriteVarOperands())
      func_read_vars = set(func.getReadVarOperands())
      
      #for op in func_write_vars:
      #  print op
      
      if len(func_write_vars.intersection(mvars)) > 0:
        ssa_map = ssa.getMap(func_read_vars.difference(mvars), func_write_vars, func_read_vars.intersection(mvars))
        
        
        cons = conds.get(ins.called_function, Condition)
        condition = cons(func, None)
        
        c = condition.getEq(func_write_vars.intersection(mvars))
        
        mvars = mvars.difference(func_write_vars) 
        mvars = func_read_vars.union(mvars)
        
        smt_conds.add(c)
    
    # additional conditions
    
    mvars = addAditionalConditions(mvars, ins, ssa, callstack, smt_conds)
    

    # no more things to do
    # we update the counter 
    counter = counter - 1    
    # we update the current call for next instruction
    callstack.prevInstruction(ins_str) 
  
  #for v in mvars:
  #  print v
  
  fvars = filter(lambda v: not (v in initial_values.keys()), mvars)
  for v in fvars:
  #  print v,n
    if not (v in initial_values) and not (":" in v.name):
      print "#Warning", str(v), "is free!" 
  
  setInitialConditions(ssa, initial_values, smt_conds)
  
  if (smt_conds.is_sat()):
    smt_conds.solve()
  
    smt_conds.write_smtlib_file("exp.smt2")  
    smt_conds.write_sol_file("exp.sol")
  
    return Solution(smt_conds.m, fvars)
  else: # unsat :(
    return None
Exemplo n.º 6
0
def getValueFromCode(reil_code, callstack, memory, addr_op, addr, val_op, val):

  assert(reil_code <> [])
  
  free_variables = []
  
  # code should be copied and reversed
  inss = list(reil_code) 
  inss.reverse()
  
  # counter is set
  counter = len(reil_code)
  
  tracked_stack_frame = callstack.index
  
  # especial operands in a call

  ssa = SSA()
  smt_conds  = SMT()
 
  # we will track op
  mvars = set([addr_op, val_op])    
  ssa_map = ssa.getMap(mvars, set(), set())
  eq = Eq(None, None)
  
  addr = Operand(str(addr), "DWORD")
  val = Operand(str(val), "BYTE")
  val.size = val_op.size
  
  smt_conds.add(eq.getEq(ssa_map[addr_op.name],addr))
  smt_conds.add(eq.getEq(ssa_map[val_op.name],val))

  for ins_str in inss:
    #print ins_str.strip("\n")
    
    pins = parse_reil(ins_str)
    
    ins = Instruction(pins, memory.getAccess(counter), mem_regs = False)
  
    ins_write_vars = set(ins.getWriteVarOperands())
    ins_read_vars = set(ins.getReadVarOperands())

    if pins.instruction == "jcc" or len(ins_write_vars.intersection(mvars)) > 0: 
    #if len(ins_write_vars.intersection(mvars)) > 0: 
      
      ssa_map = ssa.getMap(ins_read_vars.difference(mvars), ins_write_vars, ins_read_vars.intersection(mvars))

      cons = conds.get(pins.instruction, Condition)
      condition = cons(ins, ssa_map)
     
      mvars = mvars.difference(ins_write_vars) 
      mvars = ins_read_vars.union(mvars)
   
      smt_conds.add(condition.getEq())
      
    counter = counter - 1
    
    if len(mvars) > 0:
      tracked_stack_frame = callstack.index
    
    if pins.instruction == "call":
      
      if callstack.index == 1:
        esp_val = 4
      else:
        esp_val = 8
 
      ebp_val = 0
  
      esp_op = Operand("esp","DWORD")
      ebp_op = Operand("ebp","DWORD")
  
      initial_values_at_call = dict()
      initial_values_at_call[esp_op] = Operand(str(esp_val), "DWORD")
      initial_values_at_call[ebp_op] = Operand(str(ebp_val), "DWORD") 

      
      for iop in initial_values_at_call.keys():
        if not (iop in mvars):
          del initial_values_at_call[iop]
      
      ssa_map = ssa.getMap(set(), set(), set(initial_values_at_call.keys()))
      eq = Eq(None, None)
    
      for iop in initial_values_at_call:
        smt_conds.add(eq.getEq(ssa_map[iop.name],initial_values_at_call[iop]))
      
      mvars = set(filter(lambda o: not (o in initial_values_at_call.keys()), mvars))
      
      if (counter == 0 and len(mvars)>0):
        
        #cond = Initial_Cond(None, None)
        #
        #for v in mvars:
        #  print str(v),
        #  smt_conds.add(cond.getEq(v))
        
        #print "are free"
        
        #print smt_conds.solver
        free_variables = mvars
        break
      
      new_mvars = set()
      for v in mvars:
        if v.isMem(): # this should work for stack memory 
          eop = callstack.convertStackMemOp(v)
          #print eop
          smt_conds.add(eq.getEq(v,eop))
          new_mvars.add(eop)
      
      mvars = set(filter(lambda o: not (o.isMem()), mvars))
      mvars = mvars.union(new_mvars)
    
    # we update the current call for next instruction
    callstack.prevInstruction(ins_str) 
    
  #op.name = op.name+"_0"
  smt_conds.solve()
  smt_conds.write_smtlib_file("exp.smt2")
  smt_conds.write_sol_file("exp.sol")
  
  if (smt_conds.is_sat()):
    print "Solution:",
    for v in free_variables:
      if v.isReg():
        if (v in ssa_map):
          print v,smt_conds.getValue(ssa_map[v])
      elif v.isMem():
        sname, offset = stack.read(v)
        v.mem_source = sname
        print v, smt_conds.getValue(v)
  else:
    print "Not exploitable"