Exemple #1
0
def _visit(node, visited, variables_written, variables_to_write):
    """
    Depth-first walk from ``node`` collecting the variables of
    ``variables_to_write`` that are not written on some path.

    :param node: the node to explore
    :param visited: nodes already seen on the current path (not mutated: a new
        list is built before recursing)
    :param variables_written: variables written on the path so far
        (not mutated: a new list is built when a write is found)
    :param variables_to_write: the variables expected to be written
    :return: the variables of ``variables_to_write`` not written on a path
        ending at a leaf node (no sons) that is neither a THROW nor a RETURN
        node. A path reaching ``revert()``/``revert(string)`` contributes
        nothing. The list may contain duplicates when several paths miss the
        same variable.
    """
    if node in visited:
        return []

    visited = visited + [node]

    # Reference variable -> variable it was derived from (Index / Member);
    # only references created in this node are tracked
    refs = {}
    for ir in node.irs:
        if isinstance(ir, SolidityCall):
            # TODO convert the revert to a THROW node
            if ir.function in [
                    SolidityFunction('revert(string)'),
                    SolidityFunction('revert()')
            ]:
                return []

        if not isinstance(ir, OperationWithLValue):
            continue
        if isinstance(ir, (Index, Member)):
            refs[ir.lvalue] = ir.variable_left

        variables_written = variables_written + [ir.lvalue]
        lvalue = ir.lvalue
        # Writing through a reference also writes the underlying variable.
        # A reference that is not in ``refs`` (created in another node, or by
        # an operation other than Index/Member) ends the walk instead of
        # raising KeyError.
        while isinstance(lvalue, ReferenceVariable) and lvalue in refs:
            variables_written = variables_written + [refs[lvalue]]
            lvalue = refs[lvalue]

    ret = []
    if not node.sons and node.type not in [NodeType.THROW, NodeType.RETURN]:
        ret += [v for v in variables_to_write if v not in variables_written]

    for son in node.sons:
        ret += _visit(son, visited, variables_written, variables_to_write)
    return ret
Exemple #2
0
def extract_tmp_call(ins):
    """
    Turn a temporary call (``TmpCall``) into the concrete SlithIR operation.

    The operation is selected from the called value (``ins.called``) and
    from the operation the call originates from (``ins.ori``).

    :param ins: the TmpCall to convert
    :return: the operation replacing ``ins``
    :raises Exception: if none of the supported conversions applies
    """
    assert isinstance(ins, TmpCall)

    called = ins.called
    origin = ins.ori

    # Call through a variable holding a function
    if isinstance(called, Variable) and isinstance(called.type, FunctionType):
        dynamic_call = InternalDynamicCall(ins.lvalue, called, called.type)
        dynamic_call.call_id = ins.call_id
        return dynamic_call

    if isinstance(origin, Member):
        target = origin.variable_left
        member_name = origin.variable_right
        if not isinstance(target, Contract):
            # Member call on a non-contract value: high-level call
            high_level = HighLevelCall(target, member_name, ins.nbr_arguments,
                                       ins.lvalue, ins.type_call)
            high_level.call_id = ins.call_id
            return high_level
        # Member of a contract: either a structure constructor or a
        # library call
        structure = target.get_structure_from_name(member_name)
        if structure:
            new_struct = NewStructure(structure, ins.lvalue)
            new_struct.call_id = ins.call_id
            return new_struct
        library_call = LibraryCall(target, member_name, ins.nbr_arguments,
                                   ins.lvalue, ins.type_call)
        library_call.call_id = ins.call_id
        return library_call

    if isinstance(origin, TmpCall):
        return extract_tmp_call(origin)

    if isinstance(called, SolidityVariableComposed):
        composed_name = str(called)
        if composed_name == 'block.blockhash':
            # Rewritten in place, then handled as a SolidityFunction below
            ins.called = SolidityFunction('blockhash(uint256)')
            called = ins.called
        elif composed_name == 'this.balance':
            return SolidityCall(SolidityFunction('this.balance()'),
                                ins.nbr_arguments, ins.lvalue, ins.type_call)

    if isinstance(called, SolidityFunction):
        return SolidityCall(called, ins.nbr_arguments, ins.lvalue,
                            ins.type_call)

    if isinstance(origin, TmpNewElementaryType):
        return NewElementaryType(origin.type, ins.lvalue)

    if isinstance(origin, TmpNewContract):
        new_contract = NewContract(Constant(origin.contract_name), ins.lvalue)
        new_contract.call_id = ins.call_id
        return new_contract

    if isinstance(origin, TmpNewArray):
        return NewArray(origin.depth, origin.array_type, ins.lvalue)

    if isinstance(called, Structure):
        new_struct = NewStructure(called, ins.lvalue)
        new_struct.call_id = ins.call_id
        return new_struct

    if isinstance(called, Event):
        return EventCall(called.name)

    raise Exception('Not extracted {} {}'.format(type(called), ins))
def _visit(node: Node, state: State, variables_written: Set[Variable],
           variables_to_write: List[Variable]):
    """
    Explore the nodes reachable from ``node`` to find the variables of
    ``variables_to_write`` that are not written when the function returns.

    The exploration reaches a fixpoint: a node's sons are only explored again
    when the node is reached with a set of written variables that was not
    already recorded for it in ``state.nodes``.

    :param node: the node to explore
    :param state: exploration state; ``state.nodes`` maps a node to the list
        of written-variable sets it was already explored with
    :param variables_written: variables written on the path so far (copied,
        never mutated)
    :param variables_to_write: the variables expected to be written
    :return: the variables of ``variables_to_write`` not written on a path
        ending at a leaf that is neither a THROW nor a RETURN node. A path
        reaching ``revert()``/``revert(string)`` contributes nothing.
    """

    # Reference / temporary -> variable it was derived from, for this node
    refs = {}
    variables_written = set(variables_written)
    for ir in node.irs:
        if isinstance(ir, SolidityCall):
            # TODO convert the revert to a THROW node
            if ir.function in [
                    SolidityFunction('revert(string)'),
                    SolidityFunction('revert()')
            ]:
                return []

        if not isinstance(ir, OperationWithLValue):
            continue
        if isinstance(ir, (Index, Member)):
            refs[ir.lvalue] = ir.variable_left
        if isinstance(ir, (Length, Balance)):
            refs[ir.lvalue] = ir.value

        # Temporaries and references are IR artifacts, not variables to track
        if ir.lvalue and not isinstance(
                ir.lvalue, (TemporaryVariable, ReferenceVariable)):
            variables_written.add(ir.lvalue)

        # Writing through a reference also writes the underlying variable;
        # stop when the reference was not created in this node
        lvalue = ir.lvalue
        while isinstance(lvalue, ReferenceVariable) and lvalue in refs:
            if refs[lvalue] and not isinstance(
                    refs[lvalue], (TemporaryVariable, ReferenceVariable)):
                variables_written.add(refs[lvalue])
            lvalue = refs[lvalue]

    ret = []
    if not node.sons and node.type not in [NodeType.THROW, NodeType.RETURN]:
        ret += [v for v in variables_to_write if v not in variables_written]

    # Explore sons if
    # - before is None: it's the first time we explore the node
    # - variables_written is not in before: this path has a configuration of
    #   written variables that we haven't seen yet for this node
    before = state.nodes.get(node)
    if before is None or variables_written not in before:
        # setdefault works whether state.nodes is a plain dict or a
        # defaultdict(list); a plain dict would raise KeyError on a first visit
        state.nodes.setdefault(node, []).append(variables_written)
        for son in node.sons:
            ret += _visit(son, state, variables_written, variables_to_write)
    return ret
Exemple #4
0
def _can_be_destroyed(contract) -> List[Function]:
    """
    Return the entry points of ``contract`` whose operations (as returned by
    ``all_slithir_operations()``) include a ``delegatecall``/``callcode``
    low-level call or a ``selfdestruct``/``suicide`` call.

    :param contract: the contract to inspect
    :return: the matching functions of ``contract.functions_entry_points``,
        each reported once, in iteration order
    """
    destroy_functions = [
        SolidityFunction("suicide(address)"),
        SolidityFunction("selfdestruct(address)"),
    ]
    targets = []
    for f in contract.functions_entry_points:
        for ir in f.all_slithir_operations():
            # The Solidity low-level member is "callcode" (the previous
            # "codecall" spelling never matched anything)
            is_delegation = (isinstance(ir, LowLevelCall)
                             and ir.function_name in ["delegatecall", "callcode"])
            is_destroy = (isinstance(ir, SolidityCall)
                          and ir.function in destroy_functions)
            if is_delegation or is_destroy:
                targets.append(f)
                break
    return targets
Exemple #5
0
def parse_yul_function_call(root: YulScope, node: YulNode,
                            ast: Dict) -> Optional[Expression]:
    """
    Build the expression for a Yul function call AST node.

    Builtins listed in ``binary_ops`` / ``unary_ops`` become a
    BinaryOperation / UnaryOperation. Other builtins are wrapped as a
    ``SolidityFunction`` identifier. A callee resolving to a ``Function`` or a
    ``SolidityFunction`` yields a ``CallExpression`` typed from its returns.

    :raises SlitherException: if the function name does not parse to an
        Identifier, or the callee resolves to an unsupported target
    """
    call_args = [parse_yul(root, node, argument) for argument in ast["arguments"]]
    callee = parse_yul(root, node, ast["functionName"])

    if not isinstance(callee, Identifier):
        raise SlitherException(
            "expected identifier from parsing function name")

    if isinstance(callee.value, YulBuiltin):
        builtin = callee.value.name

        if builtin in binary_ops:
            # Yul shifts take the shift amount first (shl(x, y) shifts y by
            # x bits), so the operands are swapped for them
            if builtin in ["shl", "shr", "sar"]:
                lhs, rhs = call_args[1], call_args[0]
            else:
                lhs, rhs = call_args[0], call_args[1]
            return BinaryOperation(lhs, rhs, binary_ops[builtin])

        if builtin in unary_ops:
            return UnaryOperation(call_args[0], unary_ops[builtin])

        callee = Identifier(
            SolidityFunction(format_function_descriptor(builtin)))

    target = callee.value
    if isinstance(target, Function):
        return CallExpression(callee, call_args,
                              vars_to_typestr(target.returns))
    if isinstance(target, SolidityFunction):
        return CallExpression(callee, call_args,
                              vars_to_typestr(target.return_type))

    raise SlitherException(
        f"unexpected function call target type {str(type(target))}")
Exemple #6
0
    def _post_member_access(self, expression):
        """
        Generate the SlithIR for a member access expression and associate the
        resulting variable with ``expression`` (via ``set_val``).

        ``type(X).min`` / ``type(X).max`` become an Assignment of the bound of
        the elementary type ``X`` to a new temporary variable. Any other member
        access becomes a ``Member`` operation writing into a new reference
        variable.
        """
        expr = get(expression.expression)

        # Look for type(X).max / min
        # Because we looked at the AST structure, we need to look into the nested expression
        # Hopefully this is always on a direct sub field, and there is no weird construction
        if isinstance(expression.expression, CallExpression) and expression.member_name in [
            "min",
            "max",
        ]:
            if isinstance(expression.expression.called, Identifier):
                if expression.expression.called.value == SolidityFunction("type()"):
                    assert len(expression.expression.arguments) == 1
                    val = TemporaryVariable(self._node)
                    type_expression_found = expression.expression.arguments[0]
                    assert isinstance(type_expression_found, ElementaryTypeNameExpression)
                    type_found = type_expression_found.type
                    # member_name is exactly "min" or "max" here. The previous
                    # comparison against "min:" never matched, so .min
                    # silently produced the max bound.
                    if expression.member_name == "min":
                        bound = type_found.min
                    else:
                        bound = type_found.max
                    op = Assignment(val, Constant(str(bound), type_found), type_found)
                    self._result.append(op)
                    set_val(expression, val)
                    return

        val = ReferenceVariable(self._node)
        member = Member(expr, Constant(expression.member_name), val)
        member.set_expression(expression)
        self._result.append(member)
        set_val(expression, val)
Exemple #7
0
def convert_to_solidity_func(ir):
    """
    Replace a member call on ``abi`` with the matching ``abi.<name>()``
    SolidityCall, keeping the arguments and expression of ``ir``.

    Must be called after can_be_solidity_func.

    :param ir: the call to convert
    :return: the new SolidityCall; its lvalue is typed with the builtin's
        return type (unwrapped when it is a single-element list)
    """
    abi_function = SolidityFunction('abi.{}()'.format(ir.function_name))
    solidity_call = SolidityCall(abi_function, ir.nbr_arguments, ir.lvalue,
                                 ir.type_call)
    solidity_call.arguments = ir.arguments
    solidity_call.set_expression(ir.expression)

    returned = abi_function.return_type
    if isinstance(returned, list) and len(returned) == 1:
        returned = returned[0]
    solidity_call.lvalue.set_type(returned)
    return solidity_call
Exemple #8
0
def convert_to_low_level(ir):
    """
    Convert a member call to a Transfer, a Send, an ``abi.*`` SolidityCall
    or a LowLevelCall.

    The function assumes it receives a correct IR: the checks must be done
    by the caller. If no conversion applies, an error is logged and the
    process exits with status -1.

    :param ir: the call to convert
    :return: the converted operation
    """
    if ir.function_name == 'transfer':
        assert len(ir.arguments) == 1
        return Transfer(ir.destination, ir.arguments[0])

    if ir.function_name == 'send':
        assert len(ir.arguments) == 1
        send = Send(ir.destination, ir.arguments[0], ir.lvalue)
        send.lvalue.set_type(ElementaryType('bool'))
        return send

    if ir.destination.name == 'abi' and ir.function_name in ['encode',
                                                             'encodePacked',
                                                             'encodeWithSelector',
                                                             'encodeWithSignature',
                                                             'decode']:
        call = SolidityFunction('abi.{}()'.format(ir.function_name))
        new_ir = SolidityCall(call, ir.nbr_arguments, ir.lvalue, ir.type_call)
        new_ir.arguments = ir.arguments
        # Single-element return types are unwrapped
        if isinstance(call.return_type, list) and len(call.return_type) == 1:
            new_ir.lvalue.set_type(call.return_type[0])
        else:
            new_ir.lvalue.set_type(call.return_type)
        return new_ir

    if ir.function_name in ['call',
                            'delegatecall',
                            'callcode',
                            'staticcall']:
        new_ir = LowLevelCall(ir.destination,
                              ir.function_name,
                              ir.nbr_arguments,
                              ir.lvalue,
                              ir.type_call)
        new_ir.call_gas = ir.call_gas
        new_ir.call_value = ir.call_value
        new_ir.arguments = ir.arguments
        new_ir.lvalue.set_type(ElementaryType('bool'))
        return new_ir

    logger.error('Incorrect conversion to low level {}'.format(ir))
    # The bare ``exit`` builtin is added by the ``site`` module for the
    # interactive interpreter only; ``sys.exit`` is always available and
    # raises SystemExit with the same status.
    import sys
    sys.exit(-1)
Exemple #9
0
def extract_tmp_call(ins, contract):
    """
    Convert a temporary call (``TmpCall``) into its concrete SlithIR operation.

    The operation is chosen from what is called (``ins.called``) and from the
    operation the call originates from (``ins.ori``). Most branches copy
    ``ins.expression`` and ``ins.call_id`` onto the returned operation.

    :param ins: the TmpCall to convert
    :param contract: the contract containing the call; a member call on this
        contract or one of its inherited contracts is resolved to an
        InternalCall or an EventCall when the member names one of
        ``contract.functions`` / ``contract.events``
    :return: the concrete operation replacing ``ins`` (``Nop`` for a base
        constructor call that cannot be resolved)
    :raises Exception: when no conversion applies
    """
    assert isinstance(ins, TmpCall)

    if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType):
        call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type)
        call.set_expression(ins.expression)
        call.call_id = ins.call_id
        return call
    if isinstance(ins.ori, Member):
        # If there is a call on an inherited contract, it is an internal call or an event
        if ins.ori.variable_left in contract.inheritance + [contract]:
            member_name = str(ins.ori.variable_right)
            if member_name in [f.name for f in contract.functions]:
                internalcall = InternalCall((ins.ori.variable_right, ins.ori.variable_left.name), ins.nbr_arguments,
                                            ins.lvalue, ins.type_call)
                internalcall.set_expression(ins.expression)
                internalcall.call_id = ins.call_id
                return internalcall
            if member_name in [f.name for f in contract.events]:
                eventcall = EventCall(ins.ori.variable_right)
                eventcall.set_expression(ins.expression)
                eventcall.call_id = ins.call_id
                return eventcall
        if isinstance(ins.ori.variable_left, Contract):
            st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right)
            if st:
                op = NewStructure(st, ins.lvalue)
                op.set_expression(ins.expression)
                op.call_id = ins.call_id
                return op
            libcall = LibraryCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue,
                                  ins.type_call)
            libcall.set_expression(ins.expression)
            libcall.call_id = ins.call_id
            return libcall
        msgcall = HighLevelCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue,
                                ins.type_call)
        msgcall.call_id = ins.call_id

        # Gas / value options given to the call are carried over when set
        if ins.call_gas:
            msgcall.call_gas = ins.call_gas
        if ins.call_value:
            msgcall.call_value = ins.call_value
        msgcall.set_expression(ins.expression)

        return msgcall

    if isinstance(ins.ori, TmpCall):
        r = extract_tmp_call(ins.ori, contract)
        r.set_node(ins.node)
        return r
    if isinstance(ins.called, SolidityVariableComposed):
        if str(ins.called) == 'block.blockhash':
            # Rewritten in place, then handled as a SolidityFunction below
            ins.called = SolidityFunction('blockhash(uint256)')
        elif str(ins.called) == 'this.balance':
            s = SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call)
            s.set_expression(ins.expression)
            return s

    if isinstance(ins.called, SolidityFunction):
        s = SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call)
        s.set_expression(ins.expression)
        return s

    if isinstance(ins.ori, TmpNewElementaryType):
        n = NewElementaryType(ins.ori.type, ins.lvalue)
        n.set_expression(ins.expression)
        return n

    if isinstance(ins.ori, TmpNewContract):
        op = NewContract(Constant(ins.ori.contract_name), ins.lvalue)
        op.set_expression(ins.expression)
        op.call_id = ins.call_id
        return op

    if isinstance(ins.ori, TmpNewArray):
        n = NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue)
        n.set_expression(ins.expression)
        return n

    if isinstance(ins.called, Structure):
        op = NewStructure(ins.called, ins.lvalue)
        op.set_expression(ins.expression)
        op.call_id = ins.call_id
        return op

    if isinstance(ins.called, Event):
        e = EventCall(ins.called.name)
        e.set_expression(ins.expression)
        return e

    if isinstance(ins.called, Contract):
        # Called a base constructor, where there is no constructor
        if ins.called.constructor is None:
            return Nop()
        # Case where:
        # contract A{ constructor(uint) }
        # contract B is A {}
        # contract C is B{ constructor() A(10) B() {}
        # C calls B(), which does not exist
        # Ideally we should compare here for the parameters types too
        if len(ins.called.constructor.parameters) != ins.nbr_arguments:
            return Nop()
        internalcall = InternalCall(ins.called.constructor, ins.nbr_arguments, ins.lvalue,
                                    ins.type_call)
        internalcall.call_id = ins.call_id
        internalcall.set_expression(ins.expression)
        return internalcall

    raise Exception('Not extracted {} {}'.format(type(ins.called), ins))
from slither.solc_parsing.declarations.function import FunctionSolc as Slither_FunctionSolc
from slither.core.solidity_types.elementary_type import ElementaryType

from .function_call import FunctionCall
from web3 import Web3

from collections import defaultdict

from util import get_boundary_values
from slither.utils.function import get_function_id

from slither.core.declarations import SolidityFunction
from slither.slithir.operations import SolidityCall

# Builtin calls ``selfdestruct(address)`` and its older spelling
# ``suicide(address)``, as SolidityFunction objects to compare against
# SolidityCall.function.
suicide_functions = [
    SolidityFunction("selfdestruct(address)"),
    SolidityFunction("suicide(address)")
]


class Function(FunctionCall):
    """
    still need to handle for ctfuzz
        para_names, "a, b, c"
        str_tc, string format of the tc
        next_tc, index for next test case
        new_ipm, flag for check if new ipm happened.
    """
    """
    Function objects
Exemple #11
0
"""
    Module printing summary of the contract
"""

from slither.core.declarations import SolidityFunction
from slither.printers.abstract_printer import AbstractPrinter
from slither.slithir.operations import SolidityCall
from slither.utils.myprettytable import MyPrettyTable

# Signatures of the assert/require builtins this printer is concerned with
# (its HELP: "Print the require and assert calls of each function").
require_or_assert = [
    SolidityFunction("assert(bool)"),
    SolidityFunction("require(bool)"),
    SolidityFunction("require(bool,string)"),
]


class RequireOrAssert(AbstractPrinter):

    ARGUMENT = "require"
    HELP = "Print the require and assert calls of each function"

    WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#require"

    @staticmethod
    def _convert(l):
        return "\n".join(l)

    def output(self, _filename):
        """
        _filename is not used
        Args:
Exemple #12
0
    def _get_source_code(self, contract: Contract):  # pylint: disable=too-many-branches,too-many-statements
        """
        Save the source code of the contract in self._source_codes, patched
        according to the enabled options:

        - external_to_public: ``external`` functions (except fallback and
          constructor-variables functions) become ``public``, and their
          ``calldata`` parameters become ``memory``
        - private_to_internal: ``private`` state variables become ``internal``
        - remove_assert: nodes calling ``assert(bool)`` are commented out

        The file content is handled as UTF-8 bytes because the source mapping
        offsets are byte offsets.

        :param contract: the contract whose source is extracted
        :raises SlitherException: if a keyword to patch cannot be located
        """
        src_mapping = contract.source_mapping
        content = self._slither.source_code[
            src_mapping["filename_absolute"]].encode("utf8")
        start = src_mapping["start"]
        end = src_mapping["start"] + src_mapping["length"]

        to_patch = []
        # interface must use external
        if self._external_to_public and contract.contract_kind != "interface":
            for f in contract.functions_declared:
                # fallback must be external
                if f.is_fallback or f.is_constructor_variables:
                    continue
                if f.visibility == "external":
                    params_mapping = f.parameters_src().source_mapping
                    attributes_start = params_mapping["start"] + params_mapping["length"]
                    attributes_end = f.returns_src().source_mapping["start"]
                    attributes = content[attributes_start:attributes_end]
                    # content is bytes: the pattern must be bytes as well
                    regex = re.search(
                        rb"((\sexternal)\s+)|(\sexternal)$|(\)external)$",
                        attributes)
                    if regex:
                        to_patch.append(
                            Patch(
                                attributes_start + regex.span()[0] + 1,
                                "public_to_external",
                            ))
                    else:
                        raise SlitherException(
                            f"External keyword not found {f.name} "
                            f"{attributes.decode('utf8', errors='replace')}"
                        )

                    for var in f.parameters:
                        if var.location != "calldata":
                            continue
                        calldata_start = var.source_mapping["start"]
                        calldata_end = calldata_start + var.source_mapping["length"]
                        # Also accept "calldata" as the last token of the
                        # parameter's span; a failed search used to patch the
                        # start of the parameter instead
                        keyword = re.search(rb"\scalldata(\s|$)",
                                            content[calldata_start:calldata_end])
                        if keyword is None:
                            raise SlitherException(
                                f"calldata keyword not found {f.name} {var.name}")
                        to_patch.append(
                            Patch(
                                calldata_start + keyword.start() + 1,
                                "calldata_to_memory",
                            ))

        if self._private_to_internal:
            for variable in contract.state_variables_declared:
                if variable.visibility == "private":
                    attributes_start = variable.source_mapping["start"]
                    attributes_end = attributes_start + variable.source_mapping[
                        "length"]
                    attributes = content[attributes_start:attributes_end]
                    logger.debug("private variable %s: %s",
                                 variable.source_mapping, attributes)
                    regex = re.search(rb" private ", attributes)
                    if regex:
                        to_patch.append(
                            Patch(
                                attributes_start + regex.span()[0] + 1,
                                "private_to_internal",
                            ))
                    else:
                        raise SlitherException(
                            f"private keyword not found {variable.name} "
                            f"{attributes.decode('utf8', errors='replace')}"
                        )

        if self._remove_assert:
            for function in contract.functions_and_modifiers_declared:
                for node in function.nodes:
                    for ir in node.irs:
                        if isinstance(
                                ir, SolidityCall
                        ) and ir.function == SolidityFunction("assert(bool)"):
                            to_patch.append(
                                Patch(node.source_mapping["start"],
                                      "line_removal"))
                            logger.info(
                                f"Code commented: {node.expression} ({node.source_mapping_str})"
                            )

        # Patch type -> (keyword to overwrite, replacement).
        # Note: "public_to_external" turns ``external`` into ``public``.
        keyword_swaps = {
            "public_to_external": (b"external", b"public"),
            "private_to_internal": (b"private", b"internal"),
            "calldata_to_memory": (b"calldata", b"memory"),
        }

        # Apply from the end so the offsets of the remaining patches stay valid
        to_patch.sort(key=lambda x: x.index, reverse=True)

        content = content[start:end]
        for patch in to_patch:
            index = patch.index - start
            if patch.patch_type in keyword_swaps:
                old, new = keyword_swaps[patch.patch_type]
                content = content[:index] + new + content[index + len(old):]
            else:
                assert patch.patch_type == "line_removal"
                content = content[:index] + b" // " + content[index:]

        self._source_codes[contract] = content.decode("utf8")
Exemple #13
0
    def _get_features(self, contract):  # pylint: disable=too-many-branches
        """
        Collect the feature flags summarizing ``contract``.

        "AbiEncoderV2" comes from the pragma directives located in the
        contract's file. "Upgradeable" and "Proxy" are read from the contract.
        The other flags come from the attributes and SlithIR operations of
        ``contract.functions``.

        :param contract: the contract to summarize
        :return: dict mapping each feature label to its value
        """
        contract_file = contract.source_mapping["filename_absolute"]

        use_abi_encoder = False
        for directive in self.slither.pragma_directives:
            if directive.source_mapping["filename_absolute"] != contract_file:
                continue
            if directive.is_abi_encoder_v2:
                use_abi_encoder = True

        has_payable = False
        has_assembly = False
        can_send_eth = False
        can_selfdestruct = False
        has_ecrecover = False
        can_delegatecall = False
        has_token_interaction = False

        for func in contract.functions:
            if func.payable:
                has_payable = True
            if func.contains_assembly:
                has_assembly = True

            for op in func.slithir_operations:
                # Any call carrying a value sends ether
                if isinstance(op, (LowLevelCall, HighLevelCall, Send, Transfer)) and op.call_value:
                    can_send_eth = True

                if isinstance(op, SolidityCall):
                    if op.function in [
                        SolidityFunction("suicide(address)"),
                        SolidityFunction("selfdestruct(address)"),
                    ]:
                        can_selfdestruct = True
                    if op.function == SolidityFunction(
                        "ecrecover(bytes32,uint8,bytes32,bytes32)"
                    ):
                        has_ecrecover = True

                if isinstance(op, LowLevelCall):
                    if op.function_name in ["delegatecall", "callcode"]:
                        can_delegatecall = True

                if isinstance(op, HighLevelCall):
                    callee = op.function
                    if isinstance(callee, (Function, StateVariable)):
                        if callee.contract.is_possible_token:
                            has_token_interaction = True

        return {
            "Receive ETH": has_payable,
            "Send ETH": can_send_eth,
            "Selfdestruct": can_selfdestruct,
            "Ecrecover": has_ecrecover,
            "Delegatecall": can_delegatecall,
            "Tokens interaction": has_token_interaction,
            "AbiEncoderV2": use_abi_encoder,
            "Assembly": has_assembly,
            "Upgradeable": contract.is_upgradeable,
            "Proxy": contract.is_upgradeable_proxy,
        }
    def _post_member_access(self, expression):
        """
        Generate the SlithIR for a member access expression and associate the
        resulting value with ``expression`` (via ``set_val``).

        Cases, checked in order:
        - ``type(X).min`` / ``type(X).max``: Assignment of the bound of the
          elementary type ``X`` to a new temporary variable
        - ``.balance`` / ``.code`` / ``.codehash`` on an ``address`` variable:
          SolidityCall to ``<member>(address)`` writing a new temporary
        - ``wrap`` / ``unwrap`` on a TypeAlias: the alias itself is forwarded
          (the call is handled by _post_call_expression)
        - a member of a contract naming a user defined type of the contract's
          file scope: that type is forwarded
        - otherwise: a Member operation writing a new reference variable
        """
        expr = get(expression.expression)

        # Look for type(X).max / min
        # Because we looked at the AST structure, we need to look into the nested expression
        # Hopefully this is always on a direct sub field, and there is no weird construction
        if isinstance(expression.expression,
                      CallExpression) and expression.member_name in [
                          "min",
                          "max",
                      ]:
            if isinstance(expression.expression.called, Identifier):
                if expression.expression.called.value == SolidityFunction(
                        "type()"):
                    assert len(expression.expression.arguments) == 1
                    val = TemporaryVariable(self._node)
                    type_expression_found = expression.expression.arguments[0]
                    assert isinstance(type_expression_found,
                                      ElementaryTypeNameExpression)
                    type_found = type_expression_found.type
                    # member_name is exactly "min" or "max" here. The previous
                    # comparison against "min:" never matched, so .min
                    # silently produced the max bound.
                    if expression.member_name == "min":
                        bound = type_found.min
                    else:
                        bound = type_found.max
                    op = Assignment(
                        val,
                        Constant(str(bound), type_found),
                        type_found,
                    )
                    self._result.append(op)
                    set_val(expression, val)
                    return

        # This does not support solidity 0.4 contract_name.balance
        if (isinstance(expr, Variable)
                and expr.type == ElementaryType("address")
                and expression.member_name in ["balance", "code", "codehash"]):
            val = TemporaryVariable(self._node)
            name = expression.member_name + "(address)"
            sol_func = SolidityFunction(name)
            s = SolidityCall(
                sol_func,
                1,
                val,
                sol_func.return_type,
            )
            s.set_expression(expression)
            s.arguments.append(expr)
            self._result.append(s)
            set_val(expression, val)
            return

        if isinstance(expr, TypeAlias) and expression.member_name in [
                "wrap", "unwrap"
        ]:
            # The logic is handled by _post_call_expression
            set_val(expression, expr)
            return

        # Early lookup to detect user defined types from other contracts definitions
        # contract A { type MyInt is int}
        # contract B { function f() public{ A.MyInt test = A.MyInt.wrap(1);}}
        # The logic is handled by _post_call_expression
        if isinstance(expr, Contract):
            if expression.member_name in expr.file_scope.user_defined_types:
                set_val(
                    expression,
                    expr.file_scope.user_defined_types[expression.member_name])
                return

        val = ReferenceVariable(self._node)
        member = Member(expr, Constant(expression.member_name), val)
        member.set_expression(expression)
        self._result.append(member)
        set_val(expression, val)
Exemple #15
0
    def _get_source_code(self, contract):
        """
        Save the source code of ``contract`` in self._source_codes, patched
        according to the enabled options:

        - external_to_public: ``external`` functions (except fallback and
          constructor-variables functions) become ``public``, and their
          ``calldata`` parameters become ``memory``
        - private_to_internal: ``private`` state variables become ``internal``
        - remove_assert: nodes calling ``assert(bool)`` are commented out

        The file content is handled as UTF-8 bytes because the source mapping
        offsets are byte offsets; slicing the ``str`` shifted every patch
        after a non-ASCII character.

        :param contract: the contract whose source is extracted
        :raises SlitherException: if a keyword to patch cannot be located
        """
        src_mapping = contract.source_mapping
        content = self._slither.source_code[
            src_mapping['filename_absolute']].encode('utf8')
        start = src_mapping['start']
        end = src_mapping['start'] + src_mapping['length']

        to_patch = []
        # interface must use external
        if self._external_to_public and contract.contract_kind != "interface":
            for f in contract.functions_declared:
                # fallback must be external
                if f.is_fallback or f.is_constructor_variables:
                    continue
                if f.visibility == 'external':
                    attributes_start = (
                        f.parameters_src.source_mapping['start'] +
                        f.parameters_src.source_mapping['length'])
                    attributes_end = f.returns_src.source_mapping['start']
                    attributes = content[attributes_start:attributes_end]
                    regex = re.search(
                        rb'((\sexternal)\s+)|(\sexternal)$|(\)external)$',
                        attributes)
                    if regex:
                        to_patch.append(
                            Patch(attributes_start + regex.span()[0] + 1,
                                  'public_to_external'))
                    else:
                        raise SlitherException(
                            f'External keyword not found {f.name} '
                            f'{attributes.decode("utf8", errors="replace")}'
                        )

                    for var in f.parameters:
                        if var.location != "calldata":
                            continue
                        calldata_start = var.source_mapping['start']
                        calldata_end = calldata_start + var.source_mapping[
                            'length']
                        # Also accept "calldata" as the last token of the
                        # parameter's span; a failed search used to patch the
                        # start of the parameter instead
                        keyword = re.search(rb'\scalldata(\s|$)',
                                            content[calldata_start:calldata_end])
                        if keyword is None:
                            raise SlitherException(
                                f'calldata keyword not found {f.name} {var.name}')
                        to_patch.append(
                            Patch(calldata_start + keyword.start() + 1,
                                  'calldata_to_memory'))

        if self._private_to_internal:
            for variable in contract.state_variables_declared:
                if variable.visibility == 'private':
                    attributes_start = variable.source_mapping['start']
                    attributes_end = attributes_start + variable.source_mapping[
                        'length']
                    attributes = content[attributes_start:attributes_end]
                    logger.debug('private variable %s: %s',
                                 variable.source_mapping, attributes)
                    regex = re.search(rb' private ', attributes)
                    if regex:
                        to_patch.append(
                            Patch(attributes_start + regex.span()[0] + 1,
                                  'private_to_internal'))
                    else:
                        # was ``v.name``: an undefined name (NameError)
                        raise SlitherException(
                            f'private keyword not found {variable.name} '
                            f'{attributes.decode("utf8", errors="replace")}')

        if self._remove_assert:
            for function in contract.functions_and_modifiers_declared:
                for node in function.nodes:
                    for ir in node.irs:
                        if isinstance(
                                ir, SolidityCall
                        ) and ir.function == SolidityFunction('assert(bool)'):
                            to_patch.append(
                                Patch(node.source_mapping['start'],
                                      'line_removal'))
                            logger.info(
                                f'Code commented: {node.expression} ({node.source_mapping_str})'
                            )

        # Patch type -> (keyword to overwrite, replacement).
        # Note: 'public_to_external' turns ``external`` into ``public``.
        keyword_swaps = {
            'public_to_external': (b'external', b'public'),
            'private_to_internal': (b'private', b'internal'),
            'calldata_to_memory': (b'calldata', b'memory'),
        }

        # Apply from the end so the offsets of the remaining patches stay valid
        to_patch.sort(key=lambda x: x.index, reverse=True)

        content = content[start:end]
        for patch in to_patch:
            index = patch.index - start
            if patch.patch_type in keyword_swaps:
                old, new = keyword_swaps[patch.patch_type]
                content = content[:index] + new + content[index + len(old):]
            else:
                assert patch.patch_type == 'line_removal'
                content = content[:index] + b' // ' + content[index:]

        self._source_codes[contract] = content.decode('utf8')
Exemple #16
0
def extract_tmp_call(ins, contract):
    """Convert a temporary call (``TmpCall``) into a concrete IR operation.

    The replacement is chosen by inspecting ``ins.called`` and ``ins.ori``.
    The checks are order-sensitive: the first matching case wins, and the
    ``block.blockhash`` rewrite deliberately falls through to the
    ``SolidityFunction`` case.

    Args:
        ins: the ``TmpCall`` to convert.
        contract: the contract containing the call. A member call whose
            left-hand side is this contract or one of the contracts in
            ``contract.inheritance`` is first resolved against
            ``contract``'s functions and events.

    Returns:
        The replacement operation: ``InternalDynamicCall``, ``InternalCall``,
        ``EventCall``, ``NewStructure``, ``LibraryCall``, ``HighLevelCall``,
        ``SolidityCall``, ``NewElementaryType``, ``NewContract`` or
        ``NewArray``; or ``Nop`` when a base contract that has no
        constructor is called as a constructor.

    Raises:
        AssertionError: if ``ins`` is not a ``TmpCall``.
        Exception: if the call matches none of the supported patterns.
    """
    assert isinstance(ins, TmpCall)

    # Call through a variable whose type is a function type.
    if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType):
        call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type)
        call.call_id = ins.call_id
        return call

    if isinstance(ins.ori, Member):
        left = ins.ori.variable_left
        right = ins.ori.variable_right
        # If there is a call on an inherited contract, it is an internal call or an event
        if left in contract.inheritance + [contract]:
            member_name = str(right)
            if member_name in [f.name for f in contract.functions]:
                internalcall = InternalCall((right, left.name), ins.nbr_arguments,
                                            ins.lvalue, ins.type_call)
                internalcall.call_id = ins.call_id
                return internalcall
            if member_name in [e.name for e in contract.events]:
                eventcall = EventCall(right)
                eventcall.call_id = ins.call_id
                return eventcall
        if isinstance(left, Contract):
            # A structure name on a contract builds that structure; any other
            # member call on a contract is treated as a library call.
            st = left.get_structure_from_name(right)
            if st:
                op = NewStructure(st, ins.lvalue)
                op.call_id = ins.call_id
                return op
            libcall = LibraryCall(left, right, ins.nbr_arguments, ins.lvalue, ins.type_call)
            libcall.call_id = ins.call_id
            return libcall
        # Any other left-hand side becomes a HighLevelCall.
        msgcall = HighLevelCall(left, right, ins.nbr_arguments, ins.lvalue, ins.type_call)
        msgcall.call_id = ins.call_id
        return msgcall

    if isinstance(ins.ori, TmpCall):
        # Nested temporary call: convert the inner one instead.
        return extract_tmp_call(ins.ori, contract)

    if isinstance(ins.called, SolidityVariableComposed):
        if str(ins.called) == 'block.blockhash':
            # Rewrite to the blockhash builtin; the SolidityFunction case below handles it.
            ins.called = SolidityFunction('blockhash(uint256)')
        elif str(ins.called) == 'this.balance':
            return SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments,
                                ins.lvalue, ins.type_call)

    if isinstance(ins.called, SolidityFunction):
        return SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call)

    if isinstance(ins.ori, TmpNewElementaryType):
        return NewElementaryType(ins.ori.type, ins.lvalue)

    if isinstance(ins.ori, TmpNewContract):
        op = NewContract(Constant(ins.ori.contract_name), ins.lvalue)
        op.call_id = ins.call_id
        return op

    if isinstance(ins.ori, TmpNewArray):
        return NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue)

    if isinstance(ins.called, Structure):
        op = NewStructure(ins.called, ins.lvalue)
        op.call_id = ins.call_id
        return op

    if isinstance(ins.called, Event):
        # Propagate call_id exactly like the Member-based EventCall above, so
        # both construction paths yield equivalent operations.
        eventcall = EventCall(ins.called.name)
        eventcall.call_id = ins.call_id
        return eventcall

    if isinstance(ins.called, Contract):
        # Called a base constructor, where there is no constructor
        if ins.called.constructor is None:
            return Nop()
        internalcall = InternalCall(ins.called.constructor, ins.nbr_arguments, ins.lvalue,
                                    ins.type_call)
        internalcall.call_id = ins.call_id
        return internalcall

    raise Exception('Not extracted {} {}'.format(type(ins.called), ins))