class X86Translator(Translator):
    """x86 to IR Translator.

    Translates one x86 instruction at a time into a list of REIL
    instructions by dispatching on the instruction mnemonic
    (``translators.dispatcher``). The dispatcher functions receive this
    translator instance, so the flag-update, condition-evaluation and
    bit-extraction helpers defined below are available to them.
    """

    def __init__(self, architecture_mode):
        """Create a translator for the given architecture mode.

        :param architecture_mode: ``ARCH_X86_MODE_32`` or
            ``ARCH_X86_MODE_64``; selects the stack, base and instruction
            pointer registers and the word size.
        """
        super(X86Translator, self).__init__()

        # Set *Architecture Mode*. The translation of each instruction
        # into the REIL language is based on this.
        self._arch_mode = architecture_mode

        # An instance of *ArchitectureInformation*.
        self._arch_info = X86ArchitectureInformation(architecture_mode)

        # An instance of a *VariableNamer*. This is used so all the
        # temporary REIL registers are unique.
        self._ir_name_generator = VariableNamer("t", separator="")

        self._builder = ReilBuilder()

        # One 1-bit REIL register per status/control flag.
        self._flags = {
            name: ReilRegisterOperand(name, 1)
            for name in ("af", "cf", "df", "of", "pf", "sf", "zf")
        }

        # NOTE(review): for any other mode the attributes below are left
        # unset -- confirm whether X86ArchitectureInformation already
        # rejects unknown modes.
        if self._arch_mode == ARCH_X86_MODE_32:
            self._sp = ReilRegisterOperand("esp", 32)
            self._bp = ReilRegisterOperand("ebp", 32)
            self._ip = ReilRegisterOperand("eip", 32)

            self._ws = ReilImmediateOperand(4, 32)  # word size
        elif self._arch_mode == ARCH_X86_MODE_64:
            self._sp = ReilRegisterOperand("rsp", 64)
            self._bp = ReilRegisterOperand("rbp", 64)
            self._ip = ReilRegisterOperand("rip", 64)

            self._ws = ReilImmediateOperand(8, 64)  # word size

    def translate(self, instruction):
        """Return IR representation of an instruction.

        An instruction without a registered translator (or whose
        translator raises ``NotImplementedError``) is translated to a
        single UNKN instruction and logged at info level. Any other
        exception is logged and re-raised.

        :param instruction: a x86 instruction
        :type instruction: X86Instruction
        :returns: list of REIL instructions
        """
        try:
            trans_instrs = self.__translate(instruction)
        except NotImplementedError:
            unkn_instr = self._builder.gen_unkn()
            # REIL address: native address in the high bits, index 0 in
            # the low 8 bits.
            unkn_instr.address = instruction.address << 8
            trans_instrs = [unkn_instr]

            self.__log_not_supported_instruction(instruction)
        except Exception:
            # Only genuine errors are reported as translation failures;
            # KeyboardInterrupt/SystemExit propagate untouched.
            self.__log_translation_exception(instruction)

            raise

        return trans_instrs

    def reset(self):
        """Restart IR register name generator.
        """
        self._ir_name_generator.reset()

    @staticmethod
    def _byte_value(byte):
        """Return the integer value of one element of an instruction's bytes.

        Indexing or iterating a ``bytes``/``bytearray`` object on Python 3
        yields ints, whereas a Python 2 ``str`` (or a text string) yields
        1-character strings. Both forms are accepted.
        """
        return byte if isinstance(byte, int) else ord(byte)

    def _bytes_to_hex(self, instruction):
        """Return the instruction's bytes as space-separated 2-digit hex."""
        return " ".join(
            "%02x" % self._byte_value(b) for b in instruction.bytes)

    def __translate(self, instruction):
        """Translate a x86 instruction into REIL language.

        :param instruction: a x86 instruction
        :type instruction: X86Instruction
        :raises NotImplementedError: if no translator is registered for
            the instruction's mnemonic.
        """
        # Retrieve translation function.
        mnemonic = instruction.mnemonic

        # "movsd" names both the string move and the SSE scalar move.
        # Treat it as the SSE one unless the first byte is a string-move
        # opcode (0xa4 / 0xa5).
        # NOTE(review): a prefixed string move (e.g. "rep movsd", f3 a5)
        # does not have its opcode at bytes[0] -- confirm how the
        # disassembler reports such instructions.
        if instruction.mnemonic in ["movsd"]:
            first_byte = self._byte_value(instruction.bytes[0])

            if first_byte not in (0xa4, 0xa5):
                mnemonic += "_sse"

        # Translate instruction.
        if mnemonic not in translators.dispatcher:
            raise NotImplementedError("Instruction Not Implemented")

        tb = X86TranslationBuilder(self._ir_name_generator, self._arch_mode)

        translators.dispatcher[mnemonic](self, tb, instruction)

        return tb.instanciate(instruction.address)

    def __log_not_supported_instruction(self, instruction):
        """Log (info) an instruction that has no translator."""
        bytes_str = self._bytes_to_hex(instruction)

        logger.info("Instruction not supported: %s (%s [%s])",
                    instruction.mnemonic, instruction, bytes_str)

    def __log_translation_exception(self, instruction):
        """Log (error, with traceback) a failed translation."""
        bytes_str = self._bytes_to_hex(instruction)

        logger.error("Failed to translate x86 to REIL: %s (%s)",
                     instruction,
                     bytes_str,
                     exc_info=True)

    # Flag translation.
    # ======================================================================== #
    def _emit_af(self, tb, oprnd0, oprnd1, gen_combine):
        """Emit the AF computation shared by addition and subtraction.

        The low nibbles of both operands are isolated in 8-bit
        temporaries, combined with ``gen_combine`` (the builder's
        ``gen_add`` or ``gen_sub``) and bit 4 of that value is stored
        in AF.
        """
        assert oprnd0.size == oprnd1.size

        # Temporaries are requested up front, in this order, so the
        # generated register names stay stable.
        low0 = tb.temporal(8)
        low1 = tb.temporal(8)
        shifted0 = tb.temporal(8)
        shifted1 = tb.temporal(8)
        nibble0 = tb.temporal(8)
        nibble1 = tb.temporal(8)
        combined = tb.temporal(8)

        imm4 = tb.immediate(4, 8)
        immn4 = tb.immediate(-4, 8)

        af = self._flags["af"]

        # Extract lower byte.
        tb.add(self._builder.gen_str(oprnd0, low0))
        tb.add(self._builder.gen_str(oprnd1, low1))

        # Zero-extend lower 4 bits (shift by +4, then by -4).
        tb.add(self._builder.gen_bsh(low0, imm4, shifted0))
        tb.add(self._builder.gen_bsh(shifted0, immn4, nibble0))

        tb.add(self._builder.gen_bsh(low1, imm4, shifted1))
        tb.add(self._builder.gen_bsh(shifted1, immn4, nibble1))

        # Add up / subtract.
        tb.add(gen_combine(nibble0, nibble1, combined))

        # Move bit 4 to AF flag.
        tb.add(self._builder.gen_bsh(combined, immn4, af))

    def _update_af(self, tb, oprnd0, oprnd1, result):
        """Emit the AF update for an addition (``result`` is unused)."""
        self._emit_af(tb, oprnd0, oprnd1, self._builder.gen_add)

    def _update_af_sub(self, tb, oprnd0, oprnd1, result):
        """Emit the AF update for a subtraction (``result`` is unused)."""
        self._emit_af(tb, oprnd0, oprnd1, self._builder.gen_sub)

    def _update_pf(self, tb, oprnd0, oprnd1, result):
        """Emit the PF update computed from ``result``.

        ``result`` is folded onto itself with shifts of 4, 2 and 1 and
        XORs; the folded value is XORed with 1 and stored in the 1-bit PF
        register. ``oprnd0`` and ``oprnd1`` are unused.
        """
        tmp0 = tb.temporal(result.size)
        tmp1 = tb.temporal(result.size)
        tmp2 = tb.temporal(result.size)
        tmp3 = tb.temporal(result.size)
        tmp4 = tb.temporal(result.size)
        tmp5 = tb.temporal(result.size)

        imm1 = tb.immediate(1, result.size)
        immn1 = tb.immediate(-1, result.size)
        immn2 = tb.immediate(-2, result.size)
        immn4 = tb.immediate(-4, result.size)

        pf = self._flags["pf"]

        # tmp1 =  result ^ (result >> 4)
        tb.add(self._builder.gen_bsh(result, immn4, tmp0))
        tb.add(self._builder.gen_xor(result, tmp0, tmp1))

        # tmp3 =  tmp1 ^ (tmp1 >> 2)
        tb.add(self._builder.gen_bsh(tmp1, immn2, tmp2))
        tb.add(self._builder.gen_xor(tmp2, tmp1, tmp3))

        # tmp5 = tmp3 ^ (tmp3 >> 1)
        tb.add(self._builder.gen_bsh(tmp3, immn1, tmp4))
        tb.add(self._builder.gen_xor(tmp4, tmp3, tmp5))

        # Invert and save result.
        tb.add(self._builder.gen_xor(tmp5, imm1, pf))

    def _update_sf(self, tb, oprnd0, oprnd1, result):
        """Emit the SF update: bit ``oprnd0.size - 1`` of ``result``."""
        # Create temporal variables.
        tmp0 = tb.temporal(result.size)

        sign_bit = oprnd0.size - 1

        mask0 = tb.immediate(2**sign_bit, result.size)
        shift0 = tb.immediate(-sign_bit, result.size)

        sf = self._flags["sf"]

        tb.add(self._builder.gen_and(result, mask0, tmp0))  # filter sign bit
        tb.add(self._builder.gen_bsh(tmp0, shift0, sf))  # extract sign bit

    def _emit_of(self, tb, oprnd0, oprnd1, result, invert_oprnd1_sign):
        """Emit the OF computation shared by addition and subtraction.

        OF = (sign(oprnd0) ^ sign(oprnd1) ^ 1) & (sign(oprnd0) ^ sign(result)),
        where for a subtraction (``invert_oprnd1_sign`` true) the sign bit
        of ``oprnd1`` is inverted first. The sign of ``result`` is read
        at bit ``oprnd0.size - 1``.
        """
        assert oprnd0.size == oprnd1.size

        of = self._flags["of"]

        imm0 = tb.immediate(1, 1)

        # Allocation order is kept stable so the generated register
        # names do not change.
        tmp0 = tb.temporal(1)
        tmp1 = tb.temporal(1)
        tmp2 = tb.temporal(1)
        tmp3 = tb.temporal(1)

        if invert_oprnd1_sign:
            inverted_sign = tb.temporal(1)

        # Extract sign bit.
        oprnd0_sign = self._extract_sign_bit(tb, oprnd0)
        oprnd1_sign = self._extract_sign_bit(tb, oprnd1)
        result_sign = self._extract_bit(tb, result, oprnd0.size - 1)

        if invert_oprnd1_sign:
            # Invert sign bit of the second operand.
            tb.add(self._builder.gen_xor(oprnd1_sign, imm0, inverted_sign))
            oprnd1_sign = inverted_sign

        # Compute OF.
        # (sign bit oprnd0 ^ sign bit oprnd1)
        tb.add(self._builder.gen_xor(oprnd0_sign, oprnd1_sign, tmp0))
        # (sign bit oprnd0 ^ sign bit oprnd1 ^ 1)
        tb.add(self._builder.gen_xor(tmp0, imm0, tmp1))
        # (sign bit oprnd0 ^ sign bit result)
        tb.add(self._builder.gen_xor(oprnd0_sign, result_sign, tmp2))
        # (sign bit oprnd0 ^ sign bit oprnd1 ^ 1) &
        # (sign bit oprnd0 ^ sign bit result)
        tb.add(self._builder.gen_and(tmp1, tmp2, tmp3))

        # Save result.
        tb.add(self._builder.gen_str(tmp3, of))

    def _update_of(self, tb, oprnd0, oprnd1, result):
        """Emit the OF update for an addition."""
        self._emit_of(tb, oprnd0, oprnd1, result, False)

    def _update_of_sub(self, tb, oprnd0, oprnd1, result):
        """Emit the OF update for a subtraction."""
        self._emit_of(tb, oprnd0, oprnd1, result, True)

    def _update_cf(self, tb, oprnd0, oprnd1, result):
        """Emit the CF update: bit ``oprnd0.size`` of ``result``.

        The carry is the first bit above the operand width, so ``result``
        has to be wider than ``oprnd0`` for the mask to fit.
        """
        cf = self._flags["cf"]

        imm0 = tb.immediate(2**oprnd0.size, result.size)
        imm1 = tb.immediate(-oprnd0.size, result.size)

        tmp0 = tb.temporal(result.size)

        tb.add(self._builder.gen_and(result, imm0, tmp0))  # filter carry bit
        tb.add(self._builder.gen_bsh(tmp0, imm1, cf))

    def _update_zf(self, tb, oprnd0, oprnd1, result):
        """Emit the ZF update: low ``oprnd0.size`` bits of ``result`` are zero."""
        zf = self._flags["zf"]

        imm0 = tb.immediate((2**oprnd0.size) - 1, result.size)

        tmp0 = tb.temporal(oprnd0.size)

        # Filter low part of result.
        tb.add(self._builder.gen_and(result, imm0, tmp0))
        tb.add(self._builder.gen_bisz(tmp0, zf))

    def _store_in_flag(self, tb, flag, value):
        """Emit a store of the constant ``value`` into ``flag``."""
        imm = tb.immediate(value, flag.size)

        tb.add(self._builder.gen_str(imm, flag))

    def _undefine_flag(self, tb, flag):
        """Emit the update for a flag the instruction leaves undefined."""
        # NOTE: In every test I've made, each time a flag is left
        # undefined it is always set to 0.
        self._store_in_flag(tb, flag, 0)

    def _clear_flag(self, tb, flag):
        """Emit ``flag = 0``."""
        self._store_in_flag(tb, flag, 0)

    def _set_flag(self, tb, flag):
        """Emit ``flag = 1``."""
        self._store_in_flag(tb, flag, 1)

    # Condition code evaluation.
    # ======================================================================== #
    # Each method returns a REIL operand holding the value of the
    # condition: either a flag register itself or a temporary computed
    # through the translation builder ``tb``.
    def _evaluate_a(self, tb):
        """above (CF=0 and ZF=0)"""
        return tb._and_regs(tb._negate_reg(self._flags["cf"]),
                            tb._negate_reg(self._flags["zf"]))

    def _evaluate_ae(self, tb):
        """above or equal (CF=0)"""
        return tb._negate_reg(self._flags["cf"])

    def _evaluate_b(self, tb):
        """below (CF=1)"""
        return self._flags["cf"]

    def _evaluate_be(self, tb):
        """below or equal (CF=1 or ZF=1)"""
        return tb._or_regs(self._flags["cf"], self._flags["zf"])

    def _evaluate_c(self, tb):
        """carry (CF=1)"""
        return self._flags["cf"]

    def _evaluate_e(self, tb):
        """equal (ZF=1)"""
        return self._flags["zf"]

    def _evaluate_g(self, tb):
        """greater (ZF=0 and SF=OF)"""
        return tb._and_regs(
            tb._negate_reg(self._flags["zf"]),
            tb._equal_regs(self._flags["sf"], self._flags["of"]))

    def _evaluate_ge(self, tb):
        """greater or equal (SF=OF)"""
        return tb._equal_regs(self._flags["sf"], self._flags["of"])

    def _evaluate_l(self, tb):
        """less (SF != OF)"""
        return tb._unequal_regs(self._flags["sf"], self._flags["of"])

    def _evaluate_le(self, tb):
        """less or equal (ZF=1 or SF != OF)"""
        return tb._or_regs(
            self._flags["zf"],
            tb._unequal_regs(self._flags["sf"], self._flags["of"]))

    def _evaluate_na(self, tb):
        """not above (CF=1 or ZF=1)"""
        return tb._or_regs(self._flags["cf"], self._flags["zf"])

    def _evaluate_nae(self, tb):
        """not above or equal (CF=1)"""
        return self._flags["cf"]

    def _evaluate_nb(self, tb):
        """not below (CF=0)"""
        return tb._negate_reg(self._flags["cf"])

    def _evaluate_nbe(self, tb):
        """not below or equal (CF=0 and ZF=0)"""
        return tb._and_regs(tb._negate_reg(self._flags["cf"]),
                            tb._negate_reg(self._flags["zf"]))

    def _evaluate_nc(self, tb):
        """not carry (CF=0)"""
        return tb._negate_reg(self._flags["cf"])

    def _evaluate_ne(self, tb):
        """not equal (ZF=0)"""
        return tb._negate_reg(self._flags["zf"])

    def _evaluate_ng(self, tb):
        """not greater (ZF=1 or SF != OF)"""
        return tb._or_regs(
            self._flags["zf"],
            tb._unequal_regs(self._flags["sf"], self._flags["of"]))

    def _evaluate_nge(self, tb):
        """not greater or equal (SF != OF)"""
        return tb._unequal_regs(self._flags["sf"], self._flags["of"])

    def _evaluate_nl(self, tb):
        """not less (SF=OF)"""
        return tb._equal_regs(self._flags["sf"], self._flags["of"])

    def _evaluate_nle(self, tb):
        """not less or equal (ZF=0 and SF=OF)"""
        return tb._and_regs(
            tb._negate_reg(self._flags["zf"]),
            tb._equal_regs(self._flags["sf"], self._flags["of"]))

    def _evaluate_no(self, tb):
        """not overflow (OF=0)"""
        return tb._negate_reg(self._flags["of"])

    def _evaluate_np(self, tb):
        """not parity (PF=0)"""
        return tb._negate_reg(self._flags["pf"])

    def _evaluate_ns(self, tb):
        """not sign (SF=0)"""
        return tb._negate_reg(self._flags["sf"])

    def _evaluate_nz(self, tb):
        """not zero (ZF=0)"""
        return tb._negate_reg(self._flags["zf"])

    def _evaluate_o(self, tb):
        """overflow (OF=1)"""
        return self._flags["of"]

    def _evaluate_p(self, tb):
        """parity (PF=1)"""
        return self._flags["pf"]

    def _evaluate_pe(self, tb):
        """parity even (PF=1)"""
        return self._flags["pf"]

    def _evaluate_po(self, tb):
        """parity odd (PF=0)"""
        return tb._negate_reg(self._flags["pf"])

    def _evaluate_s(self, tb):
        """sign (SF=1)"""
        return self._flags["sf"]

    def _evaluate_z(self, tb):
        """zero (ZF=1)"""
        return self._flags["zf"]

    # Helpers.
    # ======================================================================== #
    def _extract_bit(self, tb, reg, bit):
        """Emit code extracting bit number ``bit`` of ``reg``.

        :returns: a fresh 1-bit temporary holding the bit.
        """
        assert 0 <= bit < reg.size

        tmp = tb.temporal(reg.size)
        ret = tb.temporal(1)

        # Shift to LSB.
        tb.add(self._builder.gen_bsh(reg, tb.immediate(-bit, reg.size), tmp))
        # Filter LSB.
        tb.add(self._builder.gen_and(tmp, tb.immediate(1, reg.size), ret))

        return ret

    def _extract_msb(self, tb, reg):
        """Emit code extracting the most significant bit of ``reg``."""
        return self._extract_bit(tb, reg, reg.size - 1)

    def _extract_sign_bit(self, tb, reg):
        """Emit code extracting the sign bit (MSB) of ``reg``."""
        return self._extract_msb(tb, reg)
class TranslationBuilder(object):
    """Collects the REIL instructions generated for one native instruction.

    Instructions (and ``Label`` markers) are appended with :meth:`add`.
    :meth:`instanciate` then assigns REIL addresses and replaces label
    targets of JCC instructions by concrete addresses. The underscore
    helpers emit small, recurring instruction sequences and return the
    operand that holds their result.
    """

    def __init__(self, ir_name_generator, architecture_information):
        """Store the name generator and architecture information.

        :param ir_name_generator: source of names for temporary registers
            (its ``get_next()`` is called by :meth:`temporal`).
        :param architecture_information: object whose ``address_size`` is
            used when sizing resolved jump targets.
        """
        self._ir_name_generator = ir_name_generator

        # Instructions and labels, in emission order.
        self._instructions = []

        self._builder = ReilBuilder()

        self._arch_info = architecture_information

    def add(self, instr):
        """Append an instruction (or label) to the translation."""
        self._instructions.append(instr)

    def temporal(self, size):
        """Return a new temporary register operand of ``size`` bits."""
        name = self._ir_name_generator.get_next()

        return ReilRegisterOperand(name, size)

    def immediate(self, value, size):
        """Return an immediate operand of ``size`` bits holding ``value``."""
        return ReilImmediateOperand(value, size)

    def label(self, name):
        """Return a label called ``name``."""
        return Label(name)

    def instanciate(self, address):
        """Assign addresses to the collected instructions and return them.

        Every element first gets ``address << 8`` as its address; the low
        8 bits are then filled in, and labels removed and resolved, by
        :meth:`_resolve_loops`. The returned list is the builder's own
        instruction list, modified in place.
        """
        pending = self._instructions

        for element in pending:
            element.address = address << 8

        return self._resolve_loops(pending)

    # Auxiliary functions
    # ======================================================================== #

    def _resolve_loops(self, instrs):
        """Strip labels from ``instrs`` and resolve JCC label targets.

        ``instrs`` is modified in place and returned. Each label maps to
        the index of the instruction that follows it; that index is OR-ed
        into the low 8 bits of the instruction address.
        """
        label_positions = {}
        real_instrs = []

        # Build a label-free copy instead of deleting while iterating.
        for element in instrs:
            if isinstance(element, Label):
                # A label points at the next real instruction.
                label_positions[element.name] = len(real_instrs)
            else:
                real_instrs.append(element)

        instrs[:] = real_instrs

        # Resolve instruction addresses and JCC targets.
        for position, reil_instr in enumerate(instrs):
            assert isinstance(reil_instr, ReilInstruction)

            reil_instr.address |= position

            if reil_instr.mnemonic == ReilMnemonic.JCC:
                jump_target = reil_instr.operands[2]

                if isinstance(jump_target, Label):
                    target_index = label_positions[jump_target.name]
                    target_address = (reil_instr.address & ~0xff) | target_index
                    target_size = self._arch_info.address_size + 8

                    reil_instr.operands[2] = ReilImmediateOperand(
                        target_address, target_size)

        return instrs

    def _all_ones_imm(self, reg):
        """Return an immediate as wide as ``reg`` with every bit set."""
        width = reg.size

        return self.immediate((2**width) - 1, width)

    def _negate_reg(self, reg):
        """Emit a bitwise NOT of ``reg``; return the result temporary."""
        negated = self.temporal(reg.size)
        ones = self._all_ones_imm(reg)

        self.add(self._builder.gen_xor(reg, ones, negated))

        return negated

    def _and_regs(self, reg1, reg2):
        """Emit ``reg1 & reg2``; the result temporary is as wide as ``reg1``."""
        out = self.temporal(reg1.size)

        self.add(self._builder.gen_and(reg1, reg2, out))

        return out

    def _or_regs(self, reg1, reg2):
        """Emit ``reg1 | reg2``; the result temporary is as wide as ``reg1``."""
        out = self.temporal(reg1.size)

        self.add(self._builder.gen_or(reg1, reg2, out))

        return out

    def _xor_regs(self, reg1, reg2):
        """Emit ``reg1 ^ reg2``; the result temporary is as wide as ``reg1``."""
        out = self.temporal(reg1.size)

        self.add(self._builder.gen_xor(reg1, reg2, out))

        return out

    def _equal_regs(self, reg1, reg2):
        """Emit the bitwise NOT of ``reg1 ^ reg2``."""
        difference = self._xor_regs(reg1, reg2)

        return self._negate_reg(difference)

    def _unequal_regs(self, reg1, reg2):
        """Emit ``reg1 ^ reg2`` (non-zero where the operands differ)."""
        return self._xor_regs(reg1, reg2)

    def _shift_reg(self, reg, sh):
        """Emit a BSH of ``reg`` by ``sh``; return the result temporary."""
        out = self.temporal(reg.size)

        self.add(self._builder.gen_bsh(reg, sh, out))

        return out

    def _extract_bit(self, reg, bit):
        """Emit code extracting bit number ``bit`` (a constant) of ``reg``.

        :returns: a fresh 1-bit temporary holding the bit.
        """
        assert 0 <= bit < reg.size

        shifted = self.temporal(reg.size)
        lsb = self.temporal(1)

        # Shift the wanted bit down to the LSB.
        amount = self.immediate(-bit, reg.size)
        self.add(self._builder.gen_bsh(reg, amount, shifted))

        # Keep only the LSB.
        one = self.immediate(1, reg.size)
        self.add(self._builder.gen_and(shifted, one, lsb))

        return lsb

    def _extract_bit_with_register(self, reg, bit):
        """Like :meth:`_extract_bit`, with the bit number held in a register.

        The bit number is resolved at runtime; it is assumed, not checked,
        to be within ``reg``.
        """
        shifted = self.temporal(reg.size)
        negated_bit = self.temporal(reg.size)
        lsb = self.temporal(1)

        # The shift amount has to be negative to move the bit towards
        # the LSB, so compute 0 - bit.
        zero = self.immediate(0, bit.size)
        self.add(self._builder.gen_sub(zero, bit, negated_bit))

        # Shift the wanted bit down to the LSB.
        self.add(self._builder.gen_bsh(reg, negated_bit, shifted))

        # Keep only the LSB.
        one = self.immediate(1, reg.size)
        self.add(self._builder.gen_and(shifted, one, lsb))

        return lsb

    def _extract_msb(self, reg):
        """Emit code extracting the most significant bit of ``reg``."""
        top_bit = reg.size - 1

        return self._extract_bit(reg, top_bit)

    def _extract_sign_bit(self, reg):
        """Emit code extracting the sign bit (MSB) of ``reg``."""
        return self._extract_msb(reg)

    def _greater_than_or_equal(self, reg1, reg2):
        """Emit a ``reg1 >= reg2`` test based on sign and overflow.

        ``reg1 - reg2`` is computed into a temporary twice as wide as the
        operands; the returned operand tells whether the sign bit of the
        difference (at the operands' width) equals the overflow of the
        subtraction.
        """
        assert reg1.size == reg2.size

        difference = self.temporal(reg1.size * 2)

        self.add(self._builder.gen_sub(reg1, reg2, difference))

        sign_bit = self._extract_bit(difference, reg1.size - 1)
        overflow_bit = self._overflow_from_sub(reg1, reg2, difference)

        return self._equal_regs(sign_bit, overflow_bit)

    def _jump_to(self, target):
        """Emit an unconditional jump (JCC on constant 1) to ``target``."""
        always = self.immediate(1, 1)

        self.add(self._builder.gen_jcc(always, target))

    def _jump_if_zero(self, reg, label):
        """Emit a jump to ``label`` taken when ``reg`` is zero."""
        zero_test = self.temporal(1)

        self.add(self._builder.gen_bisz(reg, zero_test))
        self.add(self._builder.gen_jcc(zero_test, label))

    def _add_to_reg(self, reg, value):
        """Emit ``reg + value``; return the result temporary."""
        total = self.temporal(reg.size)

        self.add(self._builder.gen_add(reg, value, total))

        return total

    def _sub_to_reg(self, reg, value):
        """Emit ``reg - value``; return the result temporary."""
        difference = self.temporal(reg.size)

        self.add(self._builder.gen_sub(reg, value, difference))

        return difference

    def _overflow_from_sub(self, oprnd0, oprnd1, result):
        """Emit the overflow bit of ``result = oprnd0 - oprnd1``.

        Overflow is set when the operand signs differ and the sign of the
        result differs from the sign of ``oprnd0``. All three signs are
        read at bit ``oprnd0.size - 1``.
        """
        sign_position = oprnd0.size - 1

        minuend_sign = self._extract_bit(oprnd0, sign_position)
        subtrahend_sign = self._extract_bit(oprnd1, sign_position)
        result_sign = self._extract_bit(result, sign_position)

        operands_differ = self._unequal_regs(minuend_sign, subtrahend_sign)
        result_differs = self._unequal_regs(minuend_sign, result_sign)

        return self._and_regs(operands_differ, result_differs)