def xor_regs(tb, reg1, reg2):
    """Emit ``reg1 XOR reg2`` and return the register holding the result.

    A fresh temporary of ``reg1.size`` bits is requested from *tb*, a REIL
    XOR instruction writing to it is appended to *tb*, and the temporary is
    returned.

    :param tb: translation builder providing ``temporal`` and ``add``
    :param reg1: first source operand (its size fixes the result size)
    :param reg2: second source operand
    :returns: the temporary that receives the XOR result
    """
    result = tb.temporal(reg1.size)
    xor_instr = ReilBuilder.gen_xor(reg1, reg2, result)

    tb.add(xor_instr)

    return result
def negate_reg(tb, reg):
    """Emit ``reg XOR all_ones_imm(tb, reg)`` and return the result register.

    NOTE(review): ``all_ones_imm`` is defined outside this view; presumably it
    yields an all-ones mask of ``reg.size`` bits, which would make this a
    bitwise NOT -- confirm against its definition.

    :param tb: translation builder providing ``temporal`` and ``add``
    :param reg: operand to negate
    :returns: a fresh temporary of ``reg.size`` bits holding the result
    """
    negated = tb.temporal(reg.size)
    mask = all_ones_imm(tb, reg)

    tb.add(ReilBuilder.gen_xor(reg, mask, negated))

    return negated
class X86Translator(Translator):

    """x86 to IR Translator."""

    def __init__(self, architecture_mode):
        """Set up the translator state for *architecture_mode*.

        :param architecture_mode: ``ARCH_X86_MODE_32`` or ``ARCH_X86_MODE_64``
        """
        super(X86Translator, self).__init__()

        # Set *Architecture Mode*. The translation of each instruction
        # into the REIL language is based on this.
        self._arch_mode = architecture_mode

        # An instance of *ArchitectureInformation*.
        self._arch_info = X86ArchitectureInformation(architecture_mode)

        # An instance of a *VariableNamer*. This is used so all the
        # temporary REIL registers are unique.
        self._ir_name_generator = VariableNamer("t", separator="")

        self._builder = ReilBuilder()

        # One 1-bit REIL register per status flag handled by the translator.
        self._flags = {
            "af": ReilRegisterOperand("af", 1),
            "cf": ReilRegisterOperand("cf", 1),
            "df": ReilRegisterOperand("df", 1),
            "of": ReilRegisterOperand("of", 1),
            "pf": ReilRegisterOperand("pf", 1),
            "sf": ReilRegisterOperand("sf", 1),
            "zf": ReilRegisterOperand("zf", 1),
        }

        # NOTE: for any other architecture mode the attributes below are left
        # unset, so later accesses to them raise AttributeError.
        if self._arch_mode == ARCH_X86_MODE_32:
            self._sp = ReilRegisterOperand("esp", 32)
            self._bp = ReilRegisterOperand("ebp", 32)
            self._ip = ReilRegisterOperand("eip", 32)

            self._ws = ReilImmediateOperand(4, 32)  # word size
        elif self._arch_mode == ARCH_X86_MODE_64:
            self._sp = ReilRegisterOperand("rsp", 64)
            self._bp = ReilRegisterOperand("rbp", 64)
            self._ip = ReilRegisterOperand("rip", 64)

            self._ws = ReilImmediateOperand(8, 64)  # word size

    def translate(self, instruction):
        """Return IR representation of an instruction.

        An instruction without a registered translator is logged and
        represented by a single UNKN instruction located at
        ``instruction.address << 8``. Any other error raised during the
        translation is logged and re-raised.

        :param instruction: a x86 instruction
        :returns: list of REIL instructions
        """
        try:
            trans_instrs = self.__translate(instruction)
        except NotImplementedError:
            unkn_instr = self._builder.gen_unkn()
            # The low byte of a REIL address is the index inside the
            # translated x86 instruction; the UNKN instruction is index 0.
            unkn_instr.address = instruction.address << 8
            trans_instrs = [unkn_instr]

            self.__log_not_supported_instruction(instruction)
        except Exception:
            # Only genuine errors are reported as translation failures;
            # KeyboardInterrupt / SystemExit propagate without being logged.
            self.__log_translation_exception(instruction)

            raise

        return trans_instrs

    def reset(self):
        """Restart IR register name generator.
        """
        self._ir_name_generator.reset()

    def __translate(self, instruction):
        """Translate a x86 instruction into REIL language.

        :param instruction: a x86 instruction
        :type instruction: X86Instruction
        :raises NotImplementedError: if no translator is registered for the
            instruction mnemonic
        """
        # Retrieve translation function.
        mnemonic = instruction.mnemonic

        # Check whether it refers to the strings instruction or the sse
        # instruction.
        if instruction.mnemonic in ["movsd"]:
            if instruction.bytes[0] not in ["\xa4", "\xa5"]:
                mnemonic += "_sse"

        if mnemonic not in translators.dispatcher:
            raise NotImplementedError("Instruction Not Implemented")

        # Translate instruction.
        tb = X86TranslationBuilder(self._ir_name_generator, self._arch_mode)

        translators.dispatcher[mnemonic](self, tb, instruction)

        return tb.instanciate(instruction.address)

    def __log_not_supported_instruction(self, instruction):
        """Log (info level) an instruction that has no translator."""
        bytes_str = " ".join("%02x" % ord(b) for b in instruction.bytes)

        logger.info("Instruction not supported: %s (%s [%s])",
                    instruction.mnemonic, instruction, bytes_str)

    def __log_translation_exception(self, instruction):
        """Log (error level, with traceback) a failed translation."""
        bytes_str = " ".join("%02x" % ord(b) for b in instruction.bytes)

        logger.error("Failed to translate x86 to REIL: %s (%s)",
                     instruction, bytes_str, exc_info=True)

    # Flag translation.
    # ======================================================================== #
    def _update_af(self, tb, oprnd0, oprnd1, result):
        """Emit the AF update for an addition of *oprnd0* and *oprnd1*.

        *result* is not used; it is kept for interface uniformity.
        """
        self._update_af_with(tb, oprnd0, oprnd1, self._builder.gen_add)

    def _update_af_sub(self, tb, oprnd0, oprnd1, result):
        """Emit the AF update for a subtraction of *oprnd1* from *oprnd0*.

        *result* is not used; it is kept for interface uniformity.
        """
        self._update_af_with(tb, oprnd0, oprnd1, self._builder.gen_sub)

    def _update_af_with(self, tb, oprnd0, oprnd1, gen_op):
        """Emit the AF computation shared by the add and sub variants.

        The low nibbles of both operands are isolated in 8-bit temporaries,
        combined with *gen_op* and the combined value, shifted right by 4, is
        stored into AF.

        :param gen_op: builder function (``gen_add`` or ``gen_sub``) used to
            combine the two nibbles
        """
        assert oprnd0.size == oprnd1.size

        # Temporaries are requested in a fixed order so the generated
        # register names do not depend on which variant is being emitted.
        (byte0, byte1, shl0, shl1,
         nibble0, nibble1, combined) = [tb.temporal(8) for _ in range(7)]

        imm4 = tb.immediate(4, 8)
        immn4 = tb.immediate(-4, 8)

        af = self._flags["af"]

        # Extract lower byte.
        tb.add(self._builder.gen_str(oprnd0, byte0))
        tb.add(self._builder.gen_str(oprnd1, byte1))

        # Zero-extend lower 4 bits (shift left, then right, by 4).
        tb.add(self._builder.gen_bsh(byte0, imm4, shl0))
        tb.add(self._builder.gen_bsh(shl0, immn4, nibble0))

        tb.add(self._builder.gen_bsh(byte1, imm4, shl1))
        tb.add(self._builder.gen_bsh(shl1, immn4, nibble1))

        # Add up / subtract.
        tb.add(gen_op(nibble0, nibble1, combined))

        # Move bit 4 to AF flag.
        tb.add(self._builder.gen_bsh(combined, immn4, af))

    def _update_pf(self, tb, oprnd0, oprnd1, result):
        """Emit the PF update computed from *result*.

        The value is folded with three shift-and-XOR steps (by 4, 2 and 1)
        and the folded value, XORed with 1, is stored into PF. The operands
        are not used.
        """
        (shr4, fold4, shr2,
         fold2, shr1, fold1) = [tb.temporal(result.size) for _ in range(6)]

        imm1 = tb.immediate(1, result.size)
        immn1 = tb.immediate(-1, result.size)
        immn2 = tb.immediate(-2, result.size)
        immn4 = tb.immediate(-4, result.size)

        pf = self._flags["pf"]

        # fold4 = result ^ (result >> 4)
        tb.add(self._builder.gen_bsh(result, immn4, shr4))
        tb.add(self._builder.gen_xor(result, shr4, fold4))

        # fold2 = fold4 ^ (fold4 >> 2)
        tb.add(self._builder.gen_bsh(fold4, immn2, shr2))
        tb.add(self._builder.gen_xor(shr2, fold4, fold2))

        # fold1 = fold2 ^ (fold2 >> 1)
        tb.add(self._builder.gen_bsh(fold2, immn1, shr1))
        tb.add(self._builder.gen_xor(shr1, fold2, fold1))

        # Invert and save result.
        tb.add(self._builder.gen_xor(fold1, imm1, pf))

    def _update_sf(self, tb, oprnd0, oprnd1, result):
        """Emit the SF update: bit ``oprnd0.size - 1`` of *result*."""
        # Create temporal variables.
        masked = tb.temporal(result.size)

        sign_mask = tb.immediate(2**(oprnd0.size - 1), result.size)
        shift = tb.immediate(-(oprnd0.size - 1), result.size)

        sf = self._flags["sf"]

        tb.add(self._builder.gen_and(result, sign_mask, masked))  # filter sign bit
        tb.add(self._builder.gen_bsh(masked, shift, sf))          # extract sign bit

    def _update_of(self, tb, oprnd0, oprnd1, result):
        """Emit the OF update for an addition of *oprnd0* and *oprnd1*."""
        self._update_of_with(tb, oprnd0, oprnd1, result, False)

    def _update_of_sub(self, tb, oprnd0, oprnd1, result):
        """Emit the OF update for a subtraction of *oprnd1* from *oprnd0*.

        Same computation as the addition case, with the sign bit of *oprnd1*
        inverted first.
        """
        self._update_of_with(tb, oprnd0, oprnd1, result, True)

    def _update_of_with(self, tb, oprnd0, oprnd1, result, invert_oprnd1):
        """Emit the OF computation shared by the add and sub variants.

        OF = (sign0 ^ sign1 ^ 1) & (sign0 ^ sign_result), where ``sign1`` is
        inverted beforehand when *invert_oprnd1* is true.
        """
        assert oprnd0.size == oprnd1.size

        of = self._flags["of"]

        imm0 = tb.immediate(1, 1)

        signs_differ = tb.temporal(1)
        signs_equal = tb.temporal(1)
        result_differs = tb.temporal(1)
        overflow = tb.temporal(1)

        # Requested here (before the sign extraction) so the temporaries keep
        # the same allocation order in both variants.
        inverted_sign = tb.temporal(1) if invert_oprnd1 else None

        # Extract sign bit.
        oprnd0_sign = self._extract_sign_bit(tb, oprnd0)
        oprnd1_sign = self._extract_sign_bit(tb, oprnd1)
        result_sign = self._extract_bit(tb, result, oprnd0.size - 1)

        if invert_oprnd1:
            # Invert sign bit of oprnd1.
            tb.add(self._builder.gen_xor(oprnd1_sign, imm0, inverted_sign))
            oprnd1_sign = inverted_sign

        # Compute OF.
        # (sign bit oprnd0 ^ sign bit oprnd1)
        tb.add(self._builder.gen_xor(oprnd0_sign, oprnd1_sign, signs_differ))
        # (sign bit oprnd0 ^ sign bit oprnd1 ^ 1)
        tb.add(self._builder.gen_xor(signs_differ, imm0, signs_equal))
        # (sign bit oprnd0 ^ sign bit result)
        tb.add(self._builder.gen_xor(oprnd0_sign, result_sign, result_differs))
        # (sign bit oprnd0 ^ sign bit oprnd1 ^ 1) &
        # (sign bit oprnd0 ^ sign bit result)
        tb.add(self._builder.gen_and(signs_equal, result_differs, overflow))

        # Save result.
        tb.add(self._builder.gen_str(overflow, of))

    def _update_cf(self, tb, oprnd0, oprnd1, result):
        """Emit the CF update: bit ``oprnd0.size`` of *result*."""
        cf = self._flags["cf"]

        carry_mask = tb.immediate(2**oprnd0.size, result.size)
        shift = tb.immediate(-oprnd0.size, result.size)

        masked = tb.temporal(result.size)

        tb.add(self._builder.gen_and(result, carry_mask, masked))  # filter carry bit
        tb.add(self._builder.gen_bsh(masked, shift, cf))

    def _update_zf(self, tb, oprnd0, oprnd1, result):
        """Emit the ZF update: low ``oprnd0.size`` bits of *result* are zero."""
        zf = self._flags["zf"]

        low_mask = tb.immediate((2**oprnd0.size) - 1, result.size)

        low_part = tb.temporal(oprnd0.size)

        tb.add(self._builder.gen_and(result, low_mask, low_part))  # filter low part of result
        tb.add(self._builder.gen_bisz(low_part, zf))

    def _undefine_flag(self, tb, flag):
        """Emit a store of 0 into *flag* (model of an undefined flag)."""
        # NOTE: In every test I've made, each time a flag is left
        # undefined it is always set to 0.

        imm = tb.immediate(0, flag.size)

        tb.add(self._builder.gen_str(imm, flag))

    def _clear_flag(self, tb, flag):
        """Emit a store of 0 into *flag*."""
        imm = tb.immediate(0, flag.size)

        tb.add(self._builder.gen_str(imm, flag))

    def _set_flag(self, tb, flag):
        """Emit a store of 1 into *flag*."""
        imm = tb.immediate(1, flag.size)

        tb.add(self._builder.gen_str(imm, flag))

    # Condition evaluation helpers.
    # ======================================================================== #
    # Each method returns a register holding the value of the condition. It is
    # either a flag register itself or a temporary produced by instructions
    # appended to *tb*. Synonymous condition codes are class-level aliases.
    def _evaluate_a(self, tb):
        # above (CF=0 and ZF=0).
        return tb._and_regs(tb._negate_reg(self._flags["cf"]),
                            tb._negate_reg(self._flags["zf"]))

    def _evaluate_ae(self, tb):
        # above or equal (CF=0)
        return tb._negate_reg(self._flags["cf"])

    def _evaluate_b(self, tb):
        # below (CF=1)
        return self._flags["cf"]

    def _evaluate_be(self, tb):
        # below or equal (CF=1 or ZF=1)
        return tb._or_regs(self._flags["cf"], self._flags["zf"])

    def _evaluate_e(self, tb):
        # equal (ZF=1)
        return self._flags["zf"]

    def _evaluate_g(self, tb):
        # greater (ZF=0 and SF=OF)
        return tb._and_regs(
            tb._negate_reg(self._flags["zf"]),
            tb._equal_regs(self._flags["sf"], self._flags["of"]))

    def _evaluate_ge(self, tb):
        # greater or equal (SF=OF)
        return tb._equal_regs(self._flags["sf"], self._flags["of"])

    def _evaluate_l(self, tb):
        # less (SF != OF)
        return tb._unequal_regs(self._flags["sf"], self._flags["of"])

    def _evaluate_le(self, tb):
        # less or equal (ZF=1 or SF != OF)
        return tb._or_regs(
            self._flags["zf"],
            tb._unequal_regs(self._flags["sf"], self._flags["of"]))

    def _evaluate_ne(self, tb):
        # not equal (ZF=0)
        return tb._negate_reg(self._flags["zf"])

    def _evaluate_no(self, tb):
        # not overflow (OF=0)
        return tb._negate_reg(self._flags["of"])

    def _evaluate_np(self, tb):
        # not parity (PF=0)
        return tb._negate_reg(self._flags["pf"])

    def _evaluate_ns(self, tb):
        # not sign (SF=0)
        return tb._negate_reg(self._flags["sf"])

    def _evaluate_o(self, tb):
        # overflow (OF=1)
        return self._flags["of"]

    def _evaluate_p(self, tb):
        # parity (PF=1)
        return self._flags["pf"]

    def _evaluate_s(self, tb):
        # sign (SF=1)
        return self._flags["sf"]

    _evaluate_c = _evaluate_b      # carry (CF=1)
    _evaluate_na = _evaluate_be    # not above (CF=1 or ZF=1)
    _evaluate_nae = _evaluate_b    # not above or equal (CF=1)
    _evaluate_nb = _evaluate_ae    # not below (CF=0)
    _evaluate_nbe = _evaluate_a    # not below or equal (CF=0 and ZF=0)
    _evaluate_nc = _evaluate_ae    # not carry (CF=0)
    _evaluate_ng = _evaluate_le    # not greater (ZF=1 or SF != OF)
    _evaluate_nge = _evaluate_l    # not greater or equal (SF != OF)
    _evaluate_nl = _evaluate_ge    # not less (SF=OF)
    _evaluate_nle = _evaluate_g    # not less or equal (ZF=0 and SF=OF)
    _evaluate_nz = _evaluate_ne    # not zero (ZF=0)
    _evaluate_pe = _evaluate_p     # parity even (PF=1)
    _evaluate_po = _evaluate_np    # parity odd (PF=0)
    _evaluate_z = _evaluate_e      # zero (ZF=1)

    # Bit extraction helpers.
    # ======================================================================== #
    def _extract_bit(self, tb, reg, bit):
        """Emit the extraction of bit number *bit* of *reg*.

        :returns: a 1-bit temporary holding the extracted bit
        """
        assert (0 <= bit < reg.size)

        shifted = tb.temporal(reg.size)
        ret = tb.temporal(1)

        tb.add(self._builder.gen_bsh(
            reg, tb.immediate(-bit, reg.size), shifted))  # shift to LSB
        tb.add(self._builder.gen_and(
            shifted, tb.immediate(1, reg.size), ret))     # filter LSB

        return ret

    def _extract_msb(self, tb, reg):
        """Emit the extraction of the most significant bit of *reg*."""
        return self._extract_bit(tb, reg, reg.size - 1)

    def _extract_sign_bit(self, tb, reg):
        """Emit the extraction of the sign bit (the MSB) of *reg*."""
        return self._extract_msb(tb, reg)
class TranslationBuilder(object):
    """Collect the REIL instructions that translate one native instruction.

    Instructions (and ``Label`` placeholders) are appended with ``add``;
    ``instanciate`` assigns their addresses and resolves the labels used as
    JCC targets.
    """

    def __init__(self, ir_name_generator, architecture_information):
        """Create an empty builder.

        :param ir_name_generator: source of names for temporary registers
        :param architecture_information: object whose ``address_size`` is
            used when jump targets are resolved
        """
        self._ir_name_generator = ir_name_generator

        self._instructions = []

        self._builder = ReilBuilder()

        self._arch_info = architecture_information

    def add(self, instr):
        """Append *instr* (an instruction or a label) to the translation."""
        self._instructions.append(instr)

    def temporal(self, size):
        """Return a new temporary register operand of *size* bits."""
        name = self._ir_name_generator.get_next()

        return ReilRegisterOperand(name, size)

    def immediate(self, value, size):
        """Return an immediate operand holding *value* in *size* bits."""
        return ReilImmediateOperand(value, size)

    def label(self, name):
        """Return a label called *name*."""
        return Label(name)

    def instanciate(self, address):
        """Assign addresses to the collected instructions and return them.

        Every collected item first gets ``address << 8`` as its address; the
        labels are then dropped and the index of each instruction is OR-ed
        into the low bits of its address (see ``_resolve_loops``).
        """
        # Set instructions address.
        base_address = address << 8

        for item in self._instructions:
            item.address = base_address

        return self._resolve_loops(self._instructions)

    # Auxiliary functions
    # ======================================================================== #
    def _resolve_loops(self, instrs):
        """Strip the labels from *instrs* and resolve the JCC label targets.

        *instrs* is modified in place and also returned. A label maps to the
        index of the instruction that follows it; a JCC whose target operand
        is a label gets an immediate target instead, made of the JCC address
        with its low byte replaced by that index.
        """
        # Collect labels, keeping only the real instructions.
        index_by_label = {}
        real_instrs = []

        for item in instrs:
            if isinstance(item, Label):
                # Index of the next real instruction.
                index_by_label[item.name] = len(real_instrs)
                continue

            real_instrs.append(item)

        instrs[:] = real_instrs

        # Resolve instruction addresses and JCC targets.
        for position, instr in enumerate(instrs):
            assert isinstance(instr, ReilInstruction)

            instr.address |= position

            if instr.mnemonic == ReilMnemonic.JCC:
                target = instr.operands[2]

                if isinstance(target, Label):
                    target_index = index_by_label[target.name]
                    target_address = (instr.address & ~0xff) | target_index

                    instr.operands[2] = ReilImmediateOperand(
                        target_address, self._arch_info.address_size + 8)

        return instrs

    def _emit_to_temporal(self, gen, first, second):
        """Emit ``gen(first, second, tmp)`` and return ``tmp``.

        ``tmp`` is a new temporary of ``first.size`` bits.
        """
        dst = self.temporal(first.size)

        self.add(gen(first, second, dst))

        return dst

    def _all_ones_imm(self, reg):
        """Return an immediate of ``reg.size`` bits with every bit set."""
        return self.immediate((2**reg.size) - 1, reg.size)

    def _negate_reg(self, reg):
        """Emit the bitwise negation of *reg* (XOR with an all-ones mask)."""
        negated = self.temporal(reg.size)
        mask = self._all_ones_imm(reg)

        self.add(self._builder.gen_xor(reg, mask, negated))

        return negated

    def _and_regs(self, reg1, reg2):
        """Emit ``reg1 AND reg2`` into a new temporary and return it."""
        return self._emit_to_temporal(self._builder.gen_and, reg1, reg2)

    def _or_regs(self, reg1, reg2):
        """Emit ``reg1 OR reg2`` into a new temporary and return it."""
        return self._emit_to_temporal(self._builder.gen_or, reg1, reg2)

    def _xor_regs(self, reg1, reg2):
        """Emit ``reg1 XOR reg2`` into a new temporary and return it."""
        return self._emit_to_temporal(self._builder.gen_xor, reg1, reg2)

    def _equal_regs(self, reg1, reg2):
        """Emit the negation of ``reg1 XOR reg2`` and return its register."""
        difference = self._xor_regs(reg1, reg2)

        return self._negate_reg(difference)

    def _unequal_regs(self, reg1, reg2):
        """Emit ``reg1 XOR reg2`` and return its register."""
        return self._xor_regs(reg1, reg2)

    def _shift_reg(self, reg, sh):
        """Emit a BSH of *reg* by *sh* into a new temporary and return it."""
        return self._emit_to_temporal(self._builder.gen_bsh, reg, sh)

    def _extract_bit(self, reg, bit):
        """Emit the extraction of bit number *bit* (an int) of *reg*.

        :returns: a 1-bit temporary holding the extracted bit
        """
        assert (0 <= bit < reg.size)

        shifted = self.temporal(reg.size)
        lsb = self.temporal(1)

        shift_amount = self.immediate(-bit, reg.size)
        self.add(self._builder.gen_bsh(reg, shift_amount, shifted))  # shift to LSB

        one = self.immediate(1, reg.size)
        self.add(self._builder.gen_and(shifted, one, lsb))           # filter LSB

        return lsb

    def _extract_bit_with_register(self, reg, bit):
        """Same as ``_extract_bit`` but *bit* is a register.

        The bit number is resolved at runtime; it is assumed, not checked,
        to be within range.
        """
        shifted = self.temporal(reg.size)
        negated_bit = self.temporal(reg.size)
        lsb = self.temporal(1)

        # negated_bit = 0 - bit, as the shift is indicated by a negative
        # number.
        zero = self.immediate(0, bit.size)
        self.add(self._builder.gen_sub(zero, bit, negated_bit))

        self.add(self._builder.gen_bsh(reg, negated_bit, shifted))   # shift to LSB

        one = self.immediate(1, reg.size)
        self.add(self._builder.gen_and(shifted, one, lsb))           # filter LSB

        return lsb

    def _extract_msb(self, reg):
        """Emit the extraction of the most significant bit of *reg*."""
        return self._extract_bit(reg, reg.size - 1)

    def _extract_sign_bit(self, reg):
        """Emit the extraction of the sign bit (the MSB) of *reg*."""
        return self._extract_msb(reg)

    def _greater_than_or_equal(self, reg1, reg2):
        """Emit the ``reg1 >= reg2`` test built from a subtraction.

        The difference is computed in a temporary twice as wide as the
        operands; the returned register compares (for equality) its bit
        ``reg1.size - 1`` with the overflow of the subtraction.
        """
        assert (reg1.size == reg2.size)

        difference = self.temporal(reg1.size * 2)

        self.add(self._builder.gen_sub(reg1, reg2, difference))

        sign = self._extract_bit(difference, reg1.size - 1)
        overflow = self._overflow_from_sub(reg1, reg2, difference)

        return self._equal_regs(sign, overflow)

    def _jump_to(self, target):
        """Emit a JCC to *target* whose condition is the immediate 1."""
        always = self.immediate(1, 1)

        self.add(self._builder.gen_jcc(always, target))

    def _jump_if_zero(self, reg, label):
        """Emit a JCC to *label* conditioned on the BISZ of *reg*."""
        is_zero = self.temporal(1)

        self.add(self._builder.gen_bisz(reg, is_zero))
        self.add(self._builder.gen_jcc(is_zero, label))

    def _add_to_reg(self, reg, value):
        """Emit ``reg + value`` into a new temporary and return it."""
        return self._emit_to_temporal(self._builder.gen_add, reg, value)

    def _sub_to_reg(self, reg, value):
        """Emit ``reg - value`` into a new temporary and return it."""
        return self._emit_to_temporal(self._builder.gen_sub, reg, value)

    def _overflow_from_sub(self, oprnd0, oprnd1, result):
        """Emit the overflow test of the subtraction ``oprnd0 - oprnd1``.

        The returned register is ``(sign0 != sign1) AND (sign0 != sign_res)``,
        every sign being read at bit ``oprnd0.size - 1``.
        """
        sign_position = oprnd0.size - 1

        op1_sign = self._extract_bit(oprnd0, sign_position)
        op2_sign = self._extract_bit(oprnd1, sign_position)
        res_sign = self._extract_bit(result, sign_position)

        operands_differ = self._unequal_regs(op1_sign, op2_sign)
        result_differs = self._unequal_regs(op1_sign, res_sign)

        return self._and_regs(operands_differ, result_differs)