예제 #1
0
def locate_transfer(vm, name):
    """Drive the contract's ``apply`` function with transfer-shaped parameters.

    The function at ``global_vars.apply_function_address`` is executed with
    ``[eos_abi_to_int(name), eos_abi_to_int('eosio.token'),
    eos_abi_to_int('transfer')]``, but only if its signature is
    ``(i64, i64, i64) -> ()``. (Presumably this is the EOSIO
    receiver/code/action triple; confirm against the callers.)

    The global mode is switched with ``global_vars.locate()`` for the duration
    of the attempt. It is always restored with ``global_vars.sym_exec()``,
    even if execution raises an unexpected error.

    Args:
        vm: virtual machine holding the store and module instance.
        name: the name of the contract, passed as the first argument.

    Returns:
        None. Outcomes are reported through ``logger`` only.
    """
    if global_vars.apply_function_address is None:
        return

    # Only an `apply` with signature (i64, i64, i64) -> () can be driven here.
    apply_func_type = structure.FunctionType()
    apply_func_type.args = bytearray([bin_format.i64, bin_format.i64, bin_format.i64])
    apply_func_type.rets = bytearray()

    apply_func = vm.store.funcs[vm.module_instance.funcaddrs[global_vars.apply_function_address]]
    global_vars.locate()
    try:
        if apply_func.functype == apply_func_type:
            params = [utils.eos_abi_to_int(name), utils.eos_abi_to_int('eosio.token'), utils.eos_abi_to_int('transfer')]
            # NOTE(review): locate() was already called above; the repeated
            # call is preserved in case it has side effects beyond setting a mode.
            global_vars.locate()
            try:
                vm.exec_by_address(global_vars.apply_function_address, params)
            except AssertionError as e:
                logger.println(f'unreachable transfer: {e}')
            except SystemExit:
                # A SystemExit escaping execution is treated as "transfer reached".
                logger.debugln('transfer found')
    finally:
        # Restore symbolic-execution mode no matter how the attempt ended, so a
        # stray exception does not leave the global state in "locate" mode.
        global_vars.sym_exec()
예제 #2
0
파일: wana.py 프로젝트: lwy-someone/WANA
    def exec_by_address(self, address: int, args: typing.List, init_constraints: typing.List = ()):
        """Execute the function stored at ``address`` with the given arguments.

        The caller's ``args`` sequence is not modified. The arguments are
        copied, each one covered by the function's declared parameter types is
        wrapped in ``sym_exec.Value``, and the copy is pushed onto a fresh
        ``sym_exec.Stack``.

        Args:
            address: the address of the function in the module instance.
            args: the parameters. i32/i64 slots take ``int`` or
                ``z3.BitVecRef``; f32/f64 slots take ``float`` or ``z3.FPRef``.
            init_constraints: initial constraints for symbolic execution.

        Returns:
            r: the result of ``sym_exec.call`` if it is truthy. Returns None
            if the target is not a ``WasmFunc`` or the result is falsy.

        Raises:
            AssertionError: if an argument does not match its declared
                val-type. Callers catch AssertionError, so these checks
                intentionally stay asserts.
        """
        # Invoke a function denoted by the function address with the provided arguments.
        func = self.store.funcs[self.module_instance.funcaddrs[address]]

        if not isinstance(func, WasmFunc):
            return None

        # Work on a private copy so the caller's list is left untouched.
        values = list(args)

        # Mapping check for Python val-type to WebAssembly val-type.
        for i, e in enumerate(func.functype.args):
            if e in [bin_format.i32, bin_format.i64]:
                assert isinstance(values[i], (int, z3.BitVecRef))

            if e in [bin_format.f32, bin_format.f64]:
                assert isinstance(values[i], (float, z3.FPRef))

            values[i] = sym_exec.Value(e, values[i])
        stack = sym_exec.Stack()
        stack.ext(values)
        logger.debugln(f'Running function address {address}({", ".join([str(e) for e in values])}):')
        r = sym_exec.call(self.module_instance, address, self.store, stack, init_constraints)
        return r if r else None
예제 #3
0
def main():
    """Run the analyses selected by the parsed command-line arguments.

    - ``args.execute``: analyse a single wasm file.
    - ``args.analyse_directory``: analyse every wasm file in a directory.
      Per-contract timings go to ``output/time.txt`` and failures go to
      ``output/exception.txt``, one ``<file>#<info>`` line each.
    - ``args.count_instruction``: count the instructions of every wasm file
      in a directory.

    Each contract analysis is bounded by ``args.timeout`` via ``func_timeout``.
    """
    args = parse_arguments()

    # `args.sol` selects the Ethereum contract type.
    if args.sol:
        global_vars.contract_type = 'ethereum'

    # Execute the export functions of a single wasm file.
    if args.execute:
        try:
            func_timeout(args.timeout,
                         execution_and_analyze,
                         args=(args.execute, ))
        except FunctionTimedOut:
            logger.println(f'{args.execute}: time out')
        except Exception as e:
            logger.debugln(traceback.format_exc())
            logger.println(f'Error: {e}')

    # Execute all export functions of every wasm file in a directory.
    if args.analyse_directory:
        wasm_files = list_wasm_in_dir(args.analyse_directory)
        # `with` guarantees both report files are flushed and closed,
        # even if the loop is interrupted.
        with open('output/time.txt', 'w') as time_fp, \
                open('output/exception.txt', 'w') as exception_fp:
            for contract_path in wasm_files:
                base_name = os.path.basename(contract_path)
                stamp = time.time()
                try:
                    func_timeout(args.timeout,
                                 execution_and_analyze,
                                 args=(contract_path, ))
                except FunctionTimedOut:
                    logger.println(f'{contract_path}: time out')
                    # Newline-terminate entries, as in time.txt, so records don't run together.
                    exception_fp.write(f'{base_name}#Time Out\n')
                except Exception as e:
                    logger.debugln(traceback.format_exc())
                    logger.println(f'Error: {e}')
                    exception_fp.write(f'{base_name}#{e}\n')
                else:
                    cost_time = time.time() - stamp
                    time_fp.write(f'{base_name}#{cost_time}\n')

    # Count the number of instructions.
    if args.count_instruction:
        wasm_files = list_wasm_in_dir(args.count_instruction)
        for file_name in wasm_files:
            logger.println('Count instruction of contract: ',
                           os.path.basename(file_name))
            try:
                vm = load(file_name)
            except Exception as e:
                logger.println(f'failed initialization: {e}')
                continue
            float_count, count = count_instruction(vm.module.funcs)
            logger.println(f'float: {float_count}  all: {count}')
예제 #4
0
파일: wana.py 프로젝트: lwy-someone/WANA
 def exec_all_func(self) -> None:
     """Execute every exported function of the module with symbolic parameters.

     Non-function exports are skipped. An error raised while executing one
     export is logged and does not stop the remaining exports from running.
     """
     for export in self.module_instance.exports:
         logger.debugln(export.name, export.value.addr)
         if export.value.extern_type != bin_format.extern_func:
             continue
         try:
             self.exec_by_address(export.value.addr, self._get_symbolic_params(export.value.addr))
         # Use a distinct name: `except ... as e` would shadow (and then
         # delete) the loop variable holding the export.
         except AttributeError as err:
             logger.println('AttributeError found: ', err)
         except z3.z3types.Z3Exception as err:
             logger.println('z3 BitVec format exception:, ', err)
         except Exception as err:
             logger.println('Undefined error: ', err)
예제 #5
0
    def exec(self, name: str, args: typing.List):
        """Execute the export function called ``name``.

        Args:
            name: the export function name, resolved via ``self.func_addr``.
            args: the parameters of the function.

        Returns:
            r: the result of ``self.exec_by_address`` if it is truthy,
            otherwise None.
        """
        func_addr = self.func_addr(name)
        # Fixed stray ')' in the log format.
        logger.debugln(f'Running function {name}:')
        r = self.exec_by_address(func_addr, args)
        # Normalise any falsy result to None.
        return r if r else None
예제 #6
0
def execution_and_analyze(contract_path):
    """Load one wasm contract and run the detection pipeline on it.

    The contract name is the file's base name up to the first ``'.wasm'``.
    The part of that name before the first ``'_'`` is what gets passed to
    ``before_sym_exec`` and ``detect_fake_eos``.

    Args:
        contract_path: path to the contract's ``.wasm`` file.

    Returns:
        None. Failures are logged rather than raised, and
        ``global_vars.clear_count()`` always runs after the pipeline.
    """
    import traceback

    name = os.path.basename(contract_path).split('.wasm')[0]
    try:
        global_vars.vm = vm = load(contract_path)
    except Exception as e:
        logger.println(f'{e}: [{name}] failed initialization')
        return
    try:
        global_vars.set_name_int64(name)
    except Exception as e:
        # Best effort: a name that cannot be encoded is only reported at debug level.
        logger.debugln(f'invalid contract name {name}: {e}')

    short_name = name.split('_')[0]
    try:
        before_sym_exec(vm, short_name)
        detect_fake_eos(vm, short_name)
        after_sym_exec(name)
    except Exception as e:
        # Keep the traceback available in debug output, like main() does.
        logger.debugln(traceback.format_exc())
        logger.println(f'Error: {e}')
    finally:
        global_vars.clear_count()
예제 #7
0
 def printex(instrs: typing.List[Instruction], prefix=0):
     """Log each instruction at debug level as an indented listing.

     ``block``/``loop``/``if`` lines show their block type and indent what
     follows by two spaces. ``end`` lines are dedented by two spaces before
     being printed. Other instructions show their immediate arguments, if any.
     """
     openers = [bin_format.block, bin_format.loop, bin_format.if_]
     for instr in instrs:
         code = instr.code
         immediates = instr.immediate_arguments
         mnemonic = bin_format.opcodes[code][0]
         if code in openers:
             logger.debugln(f'           | {" " * prefix}{mnemonic} {bin_format.blocktype[immediates][0]}')
             prefix += 2
             continue
         if code == bin_format.end:
             # Dedent first so `end` lines up with its opener.
             prefix -= 2
         line = f'           | {" " * prefix}{mnemonic}'
         if code == bin_format.end or immediates is None:
             logger.debugln(line)
         elif isinstance(immediates, list):
             rendered = " ".join(str(arg) for arg in immediates)
             logger.debugln(f'{line} {rendered}')
         else:
             logger.debugln(f'{line} {immediates}')
예제 #8
0
    def from_reader(cls, r: typing.BinaryIO) -> 'Module':
        """Parse a WebAssembly binary module from the stream ``r``.

        First the 4-byte magic number (bytes 00 61 73 6d) and the 4-byte
        version (01 00 00 00) are checked. Then sections of the form
        ``id byte, size, payload`` are read until the stream is exhausted.
        Every section is logged at debug level. Custom sections are only
        logged. The function section's type indices are combined with the
        code section's bodies into ``mod.funcs``. All other sections are
        stored on the returned ``Module``.

        Args:
            r: binary stream positioned at the start of the module.

        Returns:
            The populated ``Module``.

        Raises:
            Exception: on a bad magic number or version, a truncated section
                payload, an unknown section id, or a code section with more
                bodies than the function section declares.
        """
        if list(r.read(4)) != [0x00, 0x61, 0x73, 0x6d]:
            raise Exception('Invalid magic number!')
        if list(r.read(4)) != [0x01, 0x00, 0x00, 0x00]:
            raise Exception('Invalid version!')
        mod = Module()
        # The code section pairs each body with a type index from the function
        # section. Track that section explicitly so that a missing one is detected
        # instead of failing with UnboundLocalError.
        function_section = None
        while True:
            section_id_byte = r.read(1)
            if not section_id_byte:
                break
            section_id = ord(section_id_byte)
            n = bin_reader.read_count(r, 32)
            data = r.read(n)
            if len(data) != n:
                raise Exception('Invalid section size!')
            if section_id == bin_format.custom_section:
                custom_section = CustomSection.from_reader(io.BytesIO(data))
                logger.debugln(
                    f'{bin_format.section[section_id][0]:>9} {custom_section.name}'
                )
            elif section_id == bin_format.type_section:
                type_section = TypeSection.from_reader(io.BytesIO(data))
                for i, e in enumerate(type_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] {e}')
                mod.types = type_section.vec
            elif section_id == bin_format.import_section:
                import_section = ImportSection.from_reader(io.BytesIO(data))
                for i, e in enumerate(import_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] {e}')
                mod.imports = import_section.vec
            elif section_id == bin_format.function_section:
                function_section = FunctionSection.from_reader(
                    io.BytesIO(data))
                # Function indices start after the imported functions.
                num_imported_funcs = sum(1 for _ in filter(
                    lambda ins: ins.kind == bin_format.extern_func,
                    mod.imports))
                for i, e in enumerate(function_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] func={num_imported_funcs + i} sig={e}'
                    )
            elif section_id == bin_format.table_section:
                table_section = TableSection.from_reader(io.BytesIO(data))
                for i, e in enumerate(table_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] {e}')
                mod.tables = table_section.vec
            elif section_id == bin_format.memory_section:
                memory_section = MemorySection.from_reader(io.BytesIO(data))
                for i, e in enumerate(memory_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] {e}')
                mod.mems = memory_section.vec
            elif section_id == bin_format.global_section:
                global_section = GlobalSection.from_reader(io.BytesIO(data))
                for i, e in enumerate(global_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] {e}')
                mod.globals = global_section.vec
            elif section_id == bin_format.export_section:
                export_section = ExportSection.from_reader(io.BytesIO(data))
                for i, e in enumerate(export_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] {e}')
                mod.exports = export_section.vec
            elif section_id == bin_format.start_section:
                start_section = StartSection.from_reader(io.BytesIO(data))
                logger.debugln(
                    f'{bin_format.section[section_id][0]:>12} {start_section.start_function}'
                )
                mod.start = start_section.start_function.funcidx
            elif section_id == bin_format.element_section:
                element_section = ElementSection.from_reader(io.BytesIO(data))
                for i, e in enumerate(element_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] {e}')
                mod.elem = element_section.vec
            elif section_id == bin_format.code_section:
                code_section = CodeSection.from_reader(io.BytesIO(data))

                def printex(instrs: typing.List[Instruction], prefix=0):
                    # Debug-print the instruction listing, indenting nested blocks.
                    for e in instrs:
                        a = f'           | {" " * prefix}{bin_format.opcodes[e.code][0]}'
                        if e.code in [
                                bin_format.block, bin_format.loop,
                                bin_format.if_
                        ]:
                            logger.debugln(
                                f'{a} {bin_format.blocktype[e.immediate_arguments][0]}'
                            )
                            prefix += 2
                        elif e.code == bin_format.end:
                            prefix -= 2
                            a = f'           | {" " * prefix}{bin_format.opcodes[e.code][0]}'
                            logger.debugln(f'{a}')
                        elif e.immediate_arguments is None:
                            logger.debugln(f'{a}')
                        elif isinstance(e.immediate_arguments, list):
                            logger.debugln(
                                f'{a} {" ".join([str(e) for e in e.immediate_arguments])}'
                            )
                        else:
                            logger.debugln(f'{a} {e.immediate_arguments}')

                # Every code body needs a declared type index. Reject the module
                # with a clear message instead of an IndexError/UnboundLocalError.
                type_indices = function_section.vec if function_section is not None else []
                if len(code_section.vec) > len(type_indices):
                    raise Exception('Code section has more entries than the function section!')

                num_imported_funcs = sum(1 for _ in filter(
                    lambda ins: ins.kind == bin_format.extern_func,
                    mod.imports))
                for i, e in enumerate(code_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] func={num_imported_funcs + i} {e}'
                    )
                    printex(e.expr.data)
                    func = Function()
                    func.typeidx = type_indices[i]
                    func.locals = e.locals
                    func.expr = e.expr
                    mod.funcs.append(func)
            elif section_id == bin_format.data_section:
                data_section = DataSection.from_reader(io.BytesIO(data))
                for i, e in enumerate(data_section.vec):
                    logger.debugln(
                        f'{bin_format.section[section_id][0]:>9}[{i}] {e}')
                mod.data = data_section.vec
            else:
                raise Exception('Invalid section id!')
        logger.debugln('')
        return mod