def arbitrary_send(func):
    """Return the nodes of ``func`` that may send ether to an arbitrary user.

    Nothing is reported (an empty list is returned) when:
    - ``func.is_protected()`` is true,
    - the function calls ``ecrecover``,
    - an index is ``msg.sender`` or tainted by ``msg.sender``.

    Otherwise, a node is reported when one of its calls
    (HighLevelCall, LowLevelCall, Transfer, Send) carries a ``call_value``
    that is neither ``msg.value`` nor tainted by it, and ``ir.context[KEY]``
    is set and truthy. That flag is presumably written by an earlier taint
    pass, which is not shown here.

    Args:
        func (Function): function to analyse.

    Returns:
        list(Node): offending nodes. Always a list (possibly empty), so
        callers can iterate the result unconditionally.
    """
    if func.is_protected():
        return []

    msg_sender = SolidityVariableComposed('msg.sender')
    msg_value = SolidityVariableComposed('msg.value')
    ecrecover = SolidityFunction('ecrecover(bytes32,uint8,bytes32,bytes32)')

    ret = []
    for node in func.nodes:
        for ir in node.irs:
            if isinstance(ir, SolidityCall) and ir.function == ecrecover:
                # Return an empty list (not False) so the result type is uniform
                return []
            if isinstance(ir, Index):
                if ir.variable_right == msg_sender:
                    return []
                if is_tainted(ir.variable_right, msg_sender):
                    return []
            if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)):
                if ir.call_value is None:
                    continue
                # Forwarding the received msg.value is not an arbitrary send
                if ir.call_value == msg_value:
                    continue
                if is_tainted(ir.call_value, msg_value):
                    continue
                if KEY in ir.context and ir.context[KEY]:
                    ret.append(node)
    return ret
def output(self, filename):
    """
    Print a JSON summary of properties extracted from ``self.slither``

    The summary lists payable functions, uses of block.timestamp,
    block.number, msg.sender and msg.gas, assert usage, constant functions,
    and the constants used (including those used in binary operations).

    filename is not used

    Args:
        filename(string)
    """
    slither = self.slither

    def usage_of(variable_name):
        # Where the given composed Solidity variable is used
        return _extract_solidity_variable_usage(
            slither, SolidityVariableComposed(variable_name))

    # Keys are inserted in the same order as they appear in the JSON output
    summary = {}
    summary['payable'] = _extract_payable(slither)
    summary['timestamp'] = usage_of('block.timestamp')
    summary['block_number'] = usage_of('block.number')
    summary['msg_sender'] = usage_of('msg.sender')
    summary['msg_gas'] = usage_of('msg.gas')
    summary['assert'] = _extract_assert(slither)
    summary['constant_functions'] = _extract_constant_functions(slither)
    summary['constants_used'], summary['constants_used_in_binary'] = \
        _extract_constants(slither)

    print(json.dumps(summary, indent=4))
def arbitrary_send(self, func):
    """Return the nodes of ``func`` that may send ether to an arbitrary user.

    Nothing is reported (an empty list is returned) when:
    - ``func.is_protected()`` is true,
    - the function calls ``ecrecover``,
    - an index is ``msg.sender`` or depends on it,
    - the function makes a high-level call to
      ``transferFrom(address,address,uint256)``.

    Otherwise, a node is reported when one of its calls carries a
    ``call_value`` that is neither ``msg.value`` nor dependent on it, and
    whose destination is tainted according to ``is_tainted``.

    Args:
        func (Function): function to analyse.

    Returns:
        list(Node): offending nodes. Always a list (possibly empty).
    """
    if func.is_protected():
        return []

    contract = func.contract
    msg_sender = SolidityVariableComposed('msg.sender')
    msg_value = SolidityVariableComposed('msg.value')
    ecrecover = SolidityFunction('ecrecover(bytes32,uint8,bytes32,bytes32)')

    ret = []
    for node in func.nodes:
        for ir in node.irs:
            if isinstance(ir, SolidityCall) and ir.function == ecrecover:
                # Return an empty list (not False) so the result type is uniform
                return []
            if isinstance(ir, Index):
                if ir.variable_right == msg_sender:
                    return []
                if is_dependent(ir.variable_right, msg_sender, contract):
                    return []
            if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)):
                if (isinstance(ir, HighLevelCall)
                        and isinstance(ir.function, Function)
                        and ir.function.full_name == 'transferFrom(address,address,uint256)'):
                    return []
                if ir.call_value is None:
                    continue
                if ir.call_value == msg_value:
                    continue
                if is_dependent(ir.call_value, msg_value, contract):
                    continue
                if is_tainted(ir.destination, contract, self.slither):
                    ret.append(node)
    return ret
def detect(self):
    """Report functions that send eth to an arbitrary user.

    ``msg.value`` and ``msg.sender`` are first registered as taint sources
    with ``run_taint_variable``. Each contract is then checked with
    ``self.detect_arbitrary_send``.

    Returns:
        list: one JSON result per reported function.
    """
    results = []

    # Taint msg.value, then msg.sender
    for source in ('msg.value', 'msg.sender'):
        run_taint_variable(self.slither, SolidityVariableComposed(source))

    for contract in self.contracts:
        for (func, nodes) in self.detect_arbitrary_send(contract):
            info = "{}.{} ({}) sends eth to arbitrary user\n".format(
                func.contract.name, func.name, func.source_mapping_str)
            info += '\tDangerous calls:\n'
            for node in nodes:
                info += '\t- {} ({})\n'.format(node.expression, node.source_mapping_str)
            self.log(info)

            # Named `result` so the `json` module is not shadowed
            result = self.generate_json_result(info)
            self.add_function_to_json(func, result)
            self.add_nodes_to_json(nodes, result)
            results.append(result)

    return results
def run_taint(slither, initial_taint=None):
    """Run the taint analysis unless it already ran on ``slither``.

    Args:
        slither: object whose ``context`` records whether the analysis ran
            (under ``KEY``).
        initial_taint (list|None): taint sources. Defaults to
            ``[msg.sender, msg.value]``.
    """
    sources = initial_taint
    if sources is None:
        sources = [SolidityVariableComposed(name) for name in ('msg.sender', 'msg.value')]

    already_done = KEY in slither.context
    if not already_done:
        _run_taint(slither, sources)
def output(self, filename):  # pylint: disable=too-many-locals
    """
    Output a JSON summary of properties extracted from ``self.slither``

    The summary covers:
    - payable functions,
    - uses of block.timestamp, block.number, msg.sender and msg.gas,
    - assert usage, constant functions and the constants used,
    - function relations, constructors per contract, external calls,
      calls with a parameter, and balance usage.

    The JSON text is logged with ``self.info`` and passed to
    ``self.generate_output``, whose value is returned.

    filename is not used

    Args:
        filename(string)
    """
    slither = self.slither

    def usage_of(variable_name):
        # Where the given composed Solidity variable is used
        return _extract_solidity_variable_usage(
            slither, SolidityVariableComposed(variable_name))

    # Keys are inserted in the same order as they appear in the JSON output
    summary = {}
    summary["payable"] = _extract_payable(slither)
    summary["timestamp"] = usage_of("block.timestamp")
    summary["block_number"] = usage_of("block.number")
    summary["msg_sender"] = usage_of("msg.sender")
    summary["msg_gas"] = usage_of("msg.gas")
    summary["assert"] = _extract_assert(slither)
    summary["constant_functions"] = _extract_constant_functions(slither)
    summary["constants_used"], summary["constants_used_in_binary"] = _extract_constants(slither)
    summary["functions_relations"] = _extract_function_relations(slither)
    summary["constructors"] = {
        contract.name: contract.constructor.full_name
        for contract in slither.contracts
        if contract.constructor
    }
    summary["have_external_calls"] = _have_external_calls(slither)
    summary["call_a_parameter"] = _call_a_parameter(slither)
    summary["use_balance"] = _use_balance(slither)

    # Serialize once and reuse the text for both the log and the result
    rendered = json.dumps(summary, indent=4)
    self.info(rendered)
    return self.generate_output(rendered)
def output(self, filename):
    """
    Output a JSON summary of properties extracted from ``self.slither``

    The summary covers:
    - payable functions,
    - uses of block.timestamp, block.number, msg.sender and msg.gas,
    - assert usage, constant functions and the constants used,
    - function relations, constructors per contract, external calls,
      calls with a parameter, and balance usage.

    The JSON text is logged with ``self.info`` and passed to
    ``self.generate_output``, whose value is returned.

    filename is not used

    Args:
        filename(string)
    """
    slither = self.slither

    def usage_of(variable_name):
        # Where the given composed Solidity variable is used
        return _extract_solidity_variable_usage(
            slither, SolidityVariableComposed(variable_name))

    # Keys are inserted in the same order as they appear in the JSON output
    summary = {}
    summary['payable'] = _extract_payable(slither)
    summary['timestamp'] = usage_of('block.timestamp')
    summary['block_number'] = usage_of('block.number')
    summary['msg_sender'] = usage_of('msg.sender')
    summary['msg_gas'] = usage_of('msg.gas')
    summary['assert'] = _extract_assert(slither)
    summary['constant_functions'] = _extract_constant_functions(slither)
    summary['constants_used'], summary['constants_used_in_binary'] = _extract_constants(slither)
    summary['functions_relations'] = _extract_function_relations(slither)
    summary['constructors'] = {
        contract.name: contract.constructor.full_name
        for contract in slither.contracts
        if contract.constructor
    }
    summary['have_external_calls'] = _have_external_calls(slither)
    summary['call_a_parameter'] = _call_a_parameter(slither)
    summary['use_balance'] = _use_balance(slither)

    # Serialize once and reuse the text for both the log and the result
    rendered = json.dumps(summary, indent=4)
    self.info(rendered)
    return self.generate_output(rendered)
def detect(self):
    """Report functions that use block.timestamp in comparisons.

    ``block.timestamp`` is first registered as a taint source with
    ``run_taint_variable``. Each contract is then inspected with
    ``self.detect_dangerous_timestamp``.

    Returns:
        list: one JSON result per reported function.
    """
    # Taint block.timestamp
    run_taint_variable(self.slither, SolidityVariableComposed('block.timestamp'))

    results = []
    for contract in self.contracts:
        for func, nodes in self.detect_dangerous_timestamp(contract):
            lines = ["{}.{} ({}) uses timestamp for comparisons\n".format(
                func.contract.name, func.name, func.source_mapping_str)]
            lines.append('\tDangerous comparisons:\n')
            lines.extend('\t- {} ({})\n'.format(node.expression, node.source_mapping_str)
                         for node in nodes)
            info = ''.join(lines)
            self.log(info)

            result = self.generate_json_result(info)
            self.add_function_to_json(func, result)
            self.add_nodes_to_json(nodes, result)
            results.append(result)
    return results
def timestamp(self, func):
    """Return the nodes of ``func`` whose comparisons depend on block.timestamp.

    A node is reported when either:
    - it contains a require/assert and reads a variable dependent on
      ``block.timestamp``, or
    - one of its IRs is a boolean-returning ``Binary`` operation that reads
      such a variable.

    Returns:
        list(Node): matching nodes, sorted by ``node_id`` so the output is
        deterministic.
    """
    contract = func.contract
    block_timestamp = SolidityVariableComposed('block.timestamp')

    def _reads_timestamp(variables):
        # True if any of the variables depends on block.timestamp
        return any(is_dependent(var, block_timestamp, contract) for var in variables)

    ret = set()
    for node in func.nodes:
        if node.contains_require_or_assert() and _reads_timestamp(node.variables_read):
            ret.add(node)
            continue
        for ir in node.irs:
            if (isinstance(ir, Binary) and BinaryType.return_bool(ir.type)
                    and _reads_timestamp(ir.read)):
                ret.add(node)
                break
    # A set has no stable order; sort so repeated runs report nodes identically
    return sorted(ret, key=lambda node: node.node_id)
def contains_bad_PRNG_sources(
        func: Function, blockhash_ret_values: List[Variable]) -> List[Node]:
    """
    Check if any node in function has a modulus operator and the first operand
    is dependent on block.timestamp, now or blockhash()

    Args:
        func: function whose SSA IRs are inspected
        blockhash_ret_values: variables holding blockhash() results
            (``None`` is treated as empty)

    Returns:
        (nodes): matching nodes, sorted by ``node_id`` for deterministic output
    """
    contract = func.contract
    # Checked in this order: block.timestamp, now, then each blockhash() result
    sources = [
        SolidityVariableComposed("block.timestamp"),
        SolidityVariable("now"),
    ] + list(blockhash_ret_values or ())

    ret = set()
    for node in func.nodes:
        for ir in node.irs_ssa:
            if not (isinstance(ir, Binary) and ir.type == BinaryType.MODULO):
                continue
            if any(is_dependent_ssa(ir.variable_left, source, contract) for source in sources):
                ret.add(node)
                # The node is already reported; skip its remaining IRs
                break
    return sorted(ret, key=lambda node: node.node_id)
def _timestamp(func: Function) -> List[Node]:
    """Return the nodes of ``func`` whose conditions depend on block.timestamp or now.

    A node is reported when it contains a require/assert reading such a
    variable, or when one of its boolean-returning ``Binary`` IRs reads one.
    The result is sorted by ``node_id``.
    """
    time_sources = (SolidityVariableComposed("block.timestamp"), SolidityVariable("now"))
    contract = func.contract

    def depends_on_time(variables) -> bool:
        return any(
            is_dependent(var, source, contract)
            for var in variables
            for source in time_sources
        )

    flagged = set()
    for node in func.nodes:
        in_condition = node.contains_require_or_assert() and depends_on_time(node.variables_read)
        in_comparison = any(
            isinstance(ir, Binary) and BinaryType.return_bool(ir.type) and depends_on_time(ir.read)
            for ir in node.irs
        )
        if in_condition or in_comparison:
            flagged.add(node)
    return sorted(flagged, key=lambda x: x.node_id)
def detect(self):
    """Report functions that send eth to an arbitrary user.

    ``msg.value`` and ``msg.sender`` are first registered as taint sources
    with ``run_taint_variable``. Each contract is then checked with
    ``self.detect_arbitrary_send``.

    Returns:
        list(dict): one entry per reported function, with the keys
        'vuln', 'sourceMapping', 'filename', 'contract', 'function' and 'calls'.
    """
    # Taint msg.value, then msg.sender
    for source in ('msg.value', 'msg.sender'):
        run_taint_variable(self.slither, SolidityVariableComposed(source))

    results = []
    for contract in self.contracts:
        for (func, nodes) in self.detect_arbitrary_send(contract):
            calls_str = [str(node.expression) for node in nodes]

            info = "{}.{} ({}) sends eth to arbitrary user\n".format(
                func.contract.name, func.name, func.source_mapping_str)
            info += '\tDangerous calls:\n'
            for node in nodes:
                info += '\t- {} ({})\n'.format(node.expression, node.source_mapping_str)
            self.log(info)

            results.append({
                'vuln': 'ArbitrarySend',
                'sourceMapping': [node.source_mapping for node in nodes],
                'filename': self.filename,
                'contract': func.contract.name,
                'function': func.name,
                'calls': calls_str,
            })
    return results
def arbitrary_send(func: Function):
    """Return the nodes of ``func`` that may send ether to an arbitrary user.

    Nothing is reported (an empty list is returned) when:
    - ``func.is_protected()`` is true,
    - the function calls ``ecrecover``,
    - an index is ``msg.sender`` or depends on it,
    - the function makes a high-level call to
      ``transferFrom(address,address,uint256)``.

    Otherwise, a node is reported when one of its calls carries a
    ``call_value`` that is neither ``msg.value`` nor dependent on it, and
    whose destination is tainted according to ``is_tainted``.

    Returns:
        List[Node]: offending nodes. Always a list (possibly empty), so
        callers can iterate the result unconditionally.
    """
    if func.is_protected():
        return []

    contract = func.contract
    msg_sender = SolidityVariableComposed("msg.sender")
    msg_value = SolidityVariableComposed("msg.value")
    ecrecover = SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)")

    ret: List[Node] = []
    for node in func.nodes:
        for ir in node.irs:
            if isinstance(ir, SolidityCall) and ir.function == ecrecover:
                # Return an empty list (not False) so the result type is uniform
                return []
            if isinstance(ir, Index):
                if ir.variable_right == msg_sender:
                    return []
                if is_dependent(ir.variable_right, msg_sender, contract):
                    return []
            if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)):
                if (
                    isinstance(ir, HighLevelCall)
                    and isinstance(ir.function, Function)
                    and ir.function.full_name == "transferFrom(address,address,uint256)"
                ):
                    return []
                if ir.call_value is None:
                    continue
                if ir.call_value == msg_value:
                    continue
                if is_dependent(ir.call_value, msg_value, contract):
                    continue
                if is_tainted(ir.destination, contract):
                    ret.append(node)
    return ret
def detect(self):
    """Report functions that may send ether to an arbitrary destination.

    Steps:
    1. Run ``run_taint_calls``, which looks for calls with a tainted destination.
    2. Register ``msg.value`` and ``msg.sender`` as taint sources.
    3. Check each contract with ``self.detect_arbitrary_send``.

    Returns:
        list(dict): one entry per reported function, with the keys
        'vuln', 'sourceMapping', 'filename', 'contract', 'func' and 'calls'.
    """
    # Look if the destination of a call is tainted
    run_taint_calls(self.slither)

    # Taint msg.value, then msg.sender
    for source in ('msg.value', 'msg.sender'):
        run_taint_variable(self.slither, SolidityVariableComposed(source))

    results = []
    for contract in self.contracts:
        for (func, nodes) in self.detect_arbitrary_send(contract):
            func_name = func.name
            calls_str = [str(node.expression) for node in nodes]
            txt = "Arbitrary send in {} Contract: {}, Function: {}, Calls: {}"
            info = txt.format(self.filename, contract.name, func_name, calls_str)
            self.log(info)

            results.append({
                # Labelled as an arbitrary send; it was mislabelled 'SuicidalFunc'
                'vuln': 'ArbitrarySend',
                'sourceMapping': [node.source_mapping for node in nodes],
                'filename': self.filename,
                'contract': contract.name,
                'func': func_name,
                'calls': calls_str,
            })
    return results
def is_protected(self):
    """
    Determine if the function is protected using a check on msg.sender

    Constructors are always considered protected. Any other function is
    protected when msg.sender is read in one of its conditions (loop
    conditions excluded) or is passed as an argument.

    Only direct uses of msg.sender are detected. For example, it won't work for:
        address a = msg.sender
        require(a == owner)

    Returns
        (bool)
    """
    if self.is_constructor:
        return True

    read_in_conditions = self.all_conditional_solidity_variables_read(include_loop=False)
    used_as_args = self.all_solidity_variables_used_as_args()
    guarded_vars = read_in_conditions + used_as_args
    return SolidityVariableComposed('msg.sender') in guarded_vars
def front_running(func):
    """
    Detect front running

    Reports every Transfer/Send whose destination is msg.sender. A node is
    listed once per matching IR. Functions that read any state variable
    are skipped.

    Args:
        func (Function)
    Returns:
        list(Node)
    """
    if func.state_variables_read:
        return []

    return [
        node
        for node in func.nodes
        for ir in node.irs
        if isinstance(ir, (Transfer, Send))
        and ir.destination == SolidityVariableComposed('msg.sender')
    ]
def detect_deprecation_in_expression(self, expression):
    """ Detects if an expression makes use of any deprecated standards.

    Deprecated Solidity variables are checked first, then deprecated
    Solidity functions. Matches are returned in that order.

    Returns:
        list of tuple: (detecting_signature, original_text, recommended_text)"""
    # Every value exported by the expression
    export_values = ExportValues(expression).result()

    deprecated_vars = [entry for entry in self.DEPRECATED_SOLIDITY_VARIABLE
                       if SolidityVariableComposed(entry[0]) in export_values]
    deprecated_funcs = [entry for entry in self.DEPRECATED_SOLIDITY_FUNCTIONS
                        if SolidityFunction(entry[0]) in export_values]
    return deprecated_vars + deprecated_funcs
bool ''' assert isinstance(context, (Contract, Function)) context = context.context if isinstance(variable, Constant): return False if variable == source: return True if only_unprotected: return variable in context[KEY_SSA_UNPROTECTED] and source in context[ KEY_SSA_UNPROTECTED][variable] return variable in context[KEY_SSA] and source in context[KEY_SSA][variable] GENERIC_TAINT = { SolidityVariableComposed('msg.sender'), SolidityVariableComposed('msg.value'), SolidityVariableComposed('msg.data'), SolidityVariableComposed('tx.origin') } def is_tainted(variable, context, slither, only_unprotected=False): ''' Args: variable context (Contract|Function) only_unprotected (bool): True only unprotected function are considered Returns: bool '''
def parse_expression(expression, caller_context):
    """
    Parse a (legacy-format) AST expression node into an expression object.

    Args:
        expression (dict): AST node. Its kind is read from expression['name'];
            its data comes from 'attributes' and 'children'.
        caller_context (Contract|Function): scope used to resolve
            identifiers, types and super calls.

    Returns:
        Expression: the parsed expression. Unhandled node kinds are logged
        and the process exits.
    """
    # Expression
    # = Expression ('++' | '--')
    # | NewExpression
    # | IndexAccess
    # | MemberAccess
    # | FunctionCall
    # | '(' Expression ')'
    # | ('!' | '~' | 'delete' | '++' | '--' | '+' | '-') Expression
    # | Expression '**' Expression
    # | Expression ('*' | '/' | '%') Expression
    # | Expression ('+' | '-') Expression
    # | Expression ('<<' | '>>') Expression
    # | Expression '&' Expression
    # | Expression '^' Expression
    # | Expression '|' Expression
    # | Expression ('<' | '>' | '<=' | '>=') Expression
    # | Expression ('==' | '!=') Expression
    # | Expression '&&' Expression
    # | Expression '||' Expression
    # | Expression '?' Expression ':' Expression
    # | Expression ('=' | '|=' | '^=' | '&=' | '<<=' | '>>=' | '+=' | '-=' | '*=' | '/=' | '%=') Expression
    # | PrimaryExpression

    # The AST naming does not follow the spec
    name = expression['name']

    if name == 'UnaryOperation':
        attributes = expression['attributes']
        assert 'prefix' in attributes
        operation_type = UnaryOperationType.get_type(attributes['operator'], attributes['prefix'])

        assert len(expression['children']) == 1
        expression = parse_expression(expression['children'][0], caller_context)
        unary_op = UnaryOperation(expression, operation_type)
        return unary_op

    elif name == 'BinaryOperation':
        attributes = expression['attributes']
        operation_type = BinaryOperationType.get_type(attributes['operator'])

        assert len(expression['children']) == 2
        left_expression = parse_expression(expression['children'][0], caller_context)
        right_expression = parse_expression(expression['children'][1], caller_context)
        binary_op = BinaryOperation(left_expression, right_expression, operation_type)
        return binary_op

    elif name == 'FunctionCall':
        return parse_call(expression, caller_context)

    elif name == 'TupleExpression':
        if 'children' not in expression:
            # Components are listed explicitly; empty slots are falsy
            attributes = expression['attributes']
            components = attributes['components']
            expressions = [
                parse_expression(c, caller_context) if c else None for c in components
            ]
        else:
            expressions = [parse_expression(e, caller_context) for e in expression['children']]
        tuple_expression = TupleExpression(expressions)
        return tuple_expression

    elif name == 'Conditional':
        children = expression['children']
        assert len(children) == 3
        if_expression = parse_expression(children[0], caller_context)
        then_expression = parse_expression(children[1], caller_context)
        else_expression = parse_expression(children[2], caller_context)
        conditional = ConditionalExpression(if_expression, then_expression, else_expression)
        return conditional

    elif name == 'Assignment':
        attributes = expression['attributes']
        children = expression['children']
        assert len(expression['children']) == 2

        left_expression = parse_expression(children[0], caller_context)
        right_expression = parse_expression(children[1], caller_context)
        operation_type = AssignmentOperationType.get_type(attributes['operator'])
        operation_return_type = attributes['type']

        assignment = AssignmentOperation(left_expression, right_expression,
                                         operation_type, operation_return_type)
        return assignment

    elif name == 'Literal':
        assert 'children' not in expression
        value = expression['attributes']['value']
        literal = Literal(value)
        return literal

    elif name == 'Identifier':
        assert 'children' not in expression
        value = expression['attributes']['value']
        if 'type' in expression['attributes']:
            type_str = expression['attributes']['type']
            if type_str:
                # Raw string: the pattern is byte-identical to the former
                # non-raw literal, without the invalid-escape warnings.
                # NOTE(review): the leading [...] is a character class (one
                # character from the set), not an alternation of keywords;
                # kept as-is to preserve the current matching behavior.
                found = re.findall(
                    r'[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)',
                    type_str)
                assert len(found) <= 1
                if found:
                    value = value + '(' + found[0] + ')'
                    value = filter_name(value)

        var = find_variable(value, caller_context)
        identifier = Identifier(var)
        return identifier

    elif name == 'IndexAccess':
        index_type = expression['attributes']['type']
        children = expression['children']
        assert len(children) == 2
        left_expression = parse_expression(children[0], caller_context)
        right_expression = parse_expression(children[1], caller_context)
        index = IndexAccess(left_expression, right_expression, index_type)
        return index

    elif name == 'MemberAccess':
        member_name = expression['attributes']['member_name']
        member_type = expression['attributes']['type']
        children = expression['children']
        assert len(children) == 1
        member_expression = parse_expression(children[0], caller_context)
        if str(member_expression) == 'super':
            super_name = parse_super_name(expression)
            if isinstance(caller_context, Contract):
                inheritance = caller_context.inheritance
            else:
                assert isinstance(caller_context, Function)
                inheritance = caller_context.contract.inheritance
            # Resolve against the first parent in `inheritance` that defines it
            var = None
            for father in inheritance:
                try:
                    var = find_variable(super_name, father)
                    break
                except VariableNotFound:
                    continue
            if var is None:
                raise VariableNotFound('Variable not found: {}'.format(super_name))
            return SuperIdentifier(var)
        member_access = MemberAccess(member_name, member_type, member_expression)
        if str(member_access) in SOLIDITY_VARIABLES_COMPOSED:
            return Identifier(SolidityVariableComposed(str(member_access)))
        return member_access

    elif name == 'ElementaryTypeNameExpression':
        # nop exression
        # uint;
        assert 'children' not in expression
        value = expression['attributes']['value']
        elementary_type = parse_type(UnknownType(value), caller_context)
        return ElementaryTypeNameExpression(elementary_type)

    # NewExpression is not a root expression, it's always the child of another expression
    elif name == 'NewExpression':
        children = expression['children']
        assert len(children) == 1
        child = children[0]

        if child['name'] == 'ArrayTypeName':
            depth = 0
            while child['name'] == 'ArrayTypeName':
                # Note: dont conserve the size of the array if provided
                child = child['children'][0]
                depth += 1

            if child['name'] == 'ElementaryTypeName':
                array_type = ElementaryType(child['attributes']['name'])
            elif child['name'] == 'UserDefinedTypeName':
                array_type = parse_type(UnknownType(child['attributes']['name']), caller_context)
            else:
                logger.error('Incorrect type array {}'.format(child))
                exit(-1)
            array = NewArray(depth, array_type)
            return array

        if child['name'] == 'ElementaryTypeName':
            elem_type = ElementaryType(child['attributes']['name'])
            new_elem = NewElementaryType(elem_type)
            return new_elem

        assert child['name'] == 'UserDefinedTypeName'

        contract_name = child['attributes']['name']
        new = NewContract(contract_name)
        return new

    elif name == 'ModifierInvocation':
        children = expression['children']
        called = parse_expression(children[0], caller_context)
        arguments = [parse_expression(a, caller_context) for a in children[1::]]

        call = CallExpression(called, arguments, 'Modifier')
        return call

    logger.error('Expression not parsed %s', name)
    exit(-1)
is_tainted(destination_indirect_1, contract, slither))) assert is_tainted(destination_indirect_1, contract, slither) destination_indirect_2 = contract.get_state_variable_from_name( 'destination_indirect_2') print('{} is tainted {}'.format( destination_indirect_2, is_tainted(destination_indirect_2, contract, slither))) assert is_tainted(destination_indirect_2, contract, slither) print('SolidityVar contract') contract = slither.get_contract_from_name('SolidityVar') addr_1 = contract.get_state_variable_from_name('addr_1') addr_2 = contract.get_state_variable_from_name('addr_2') msgsender = SolidityVariableComposed('msg.sender') print('{} is dependent of {}: {}'.format( addr_1, msgsender, is_dependent(addr_1, msgsender, contract))) assert is_dependent(addr_1, msgsender, contract) print('{} is dependent of {}: {}'.format( addr_2, msgsender, is_dependent(addr_2, msgsender, contract))) assert not is_dependent(addr_2, msgsender, contract) print('Intermediate contract') contract = slither.get_contract_from_name('Intermediate') destination = contract.get_state_variable_from_name('destination') source = contract.get_state_variable_from_name('source') print('{} is dependent of {}: {}'.format( destination, source, is_dependent(destination, source, contract))) assert is_dependent(destination, source, contract)
assert is_tainted(destination_indirect_1, contract) destination_indirect_2 = contract.get_state_variable_from_name( "destination_indirect_2") print("{} is tainted {}".format(destination_indirect_2, is_tainted(destination_indirect_2, contract))) assert is_tainted(destination_indirect_2, contract) print("SolidityVar contract") contract = slither.get_contract_from_name("SolidityVar") assert contract addr_1 = contract.get_state_variable_from_name("addr_1") assert addr_1 addr_2 = contract.get_state_variable_from_name("addr_2") assert addr_2 msgsender = SolidityVariableComposed("msg.sender") print("{} is dependent of {}: {}".format( addr_1, msgsender, is_dependent(addr_1, msgsender, contract))) assert is_dependent(addr_1, msgsender, contract) print("{} is dependent of {}: {}".format( addr_2, msgsender, is_dependent(addr_2, msgsender, contract))) assert not is_dependent(addr_2, msgsender, contract) print("Intermediate contract") contract = slither.get_contract_from_name("Intermediate") assert contract destination = contract.get_state_variable_from_name("destination") assert destination source = contract.get_state_variable_from_name("source") assert source
for ir in node.irs: if isinstance(ir, HighLevelCall): if ir.destination in taints: print("Call to tainted address found in {}".format( function.name)) if __name__ == "__main__": if len(sys.argv) != 2: print("python taint_mapping.py taint.sol") sys.exit(-1) # Init slither slither = Slither(sys.argv[1]) initial_taint = [SolidityVariableComposed("msg.sender")] initial_taint += [SolidityVariableComposed("msg.value")] KEY = "TAINT" prev_taints = [] slither.context[KEY] = initial_taint while set(prev_taints) != set(slither.context[KEY]): prev_taints = slither.context[KEY] for contract in slither.contracts: for function in contract.functions: print("Function {}".format(function.name)) slither.context[KEY] = list( set(slither.context[KEY] + function.parameters)) visit_node(function.entry_point, []) print("All variables tainted : {}".format(
def parse_expression(expression, caller_context):
    """
    Parse an AST expression node (compact or legacy format) into an expression object.

    Args:
        expression (dict): AST node. Its kind is read from
            expression[caller_context.get_key()].
        caller_context (Contract|Function): provides ``get_key()`` and
            ``is_compact_ast``. Also used to resolve identifiers, types and
            super calls.

    Returns:
        Expression: the parsed expression. Unhandled node kinds are logged
        and the process exits.
    """
    # Expression
    # = Expression ('++' | '--')
    # | NewExpression
    # | IndexAccess
    # | MemberAccess
    # | FunctionCall
    # | '(' Expression ')'
    # | ('!' | '~' | 'delete' | '++' | '--' | '+' | '-') Expression
    # | Expression '**' Expression
    # | Expression ('*' | '/' | '%') Expression
    # | Expression ('+' | '-') Expression
    # | Expression ('<<' | '>>') Expression
    # | Expression '&' Expression
    # | Expression '^' Expression
    # | Expression '|' Expression
    # | Expression ('<' | '>' | '<=' | '>=') Expression
    # | Expression ('==' | '!=') Expression
    # | Expression '&&' Expression
    # | Expression '||' Expression
    # | Expression '?' Expression ':' Expression
    # | Expression ('=' | '|=' | '^=' | '&=' | '<<=' | '>>=' | '+=' | '-=' | '*=' | '/=' | '%=') Expression
    # | PrimaryExpression

    # The AST naming does not follow the spec
    key = caller_context.get_key()
    name = expression[key]
    is_compact_ast = caller_context.is_compact_ast

    if name == 'UnaryOperation':
        if is_compact_ast:
            attributes = expression
        else:
            attributes = expression['attributes']
        assert 'prefix' in attributes
        operation_type = UnaryOperationType.get_type(attributes['operator'], attributes['prefix'])

        if is_compact_ast:
            expression = parse_expression(expression['subExpression'], caller_context)
        else:
            assert len(expression['children']) == 1
            expression = parse_expression(expression['children'][0], caller_context)
        unary_op = UnaryOperation(expression, operation_type)
        return unary_op

    elif name == 'BinaryOperation':
        if is_compact_ast:
            attributes = expression
        else:
            attributes = expression['attributes']
        operation_type = BinaryOperationType.get_type(attributes['operator'])

        if is_compact_ast:
            left_expression = parse_expression(expression['leftExpression'], caller_context)
            right_expression = parse_expression(expression['rightExpression'], caller_context)
        else:
            assert len(expression['children']) == 2
            left_expression = parse_expression(expression['children'][0], caller_context)
            right_expression = parse_expression(expression['children'][1], caller_context)
        binary_op = BinaryOperation(left_expression, right_expression, operation_type)
        return binary_op

    elif name == 'FunctionCall':
        return parse_call(expression, caller_context)

    elif name == 'TupleExpression':
        # For an expression like
        #     (a,,c) = (1,2,3)
        # the legacy AST provides only two children on the left side.
        # We check the type provided (tuple(uint256,,uint256)) to determine
        # that there is an empty variable. Otherwise we could not tell that
        # a = 1, c = 3, and 2 is lost.
        # Note: this is only possible with Solidity >= 0.4.12
        if is_compact_ast:
            expressions = [
                parse_expression(e, caller_context) if e else None
                for e in expression['components']
            ]
        else:
            if 'children' not in expression:
                attributes = expression['attributes']
                components = attributes['components']
                expressions = [
                    parse_expression(c, caller_context) if c else None for c in components
                ]
            else:
                expressions = [
                    parse_expression(e, caller_context) for e in expression['children']
                ]
            # Add None for empty tuple items
            if "attributes" in expression:
                if "type" in expression['attributes']:
                    tuple_type = expression['attributes']['type']
                    if ',,' in tuple_type or '(,' in tuple_type or ',)' in tuple_type:
                        tuple_type = tuple_type[len('tuple('):-1]
                        elems = tuple_type.split(',')
                        for idx, elem in enumerate(elems):
                            if elem == '':
                                expressions.insert(idx, None)
        tuple_expression = TupleExpression(expressions)
        return tuple_expression

    elif name == 'Conditional':
        if is_compact_ast:
            if_expression = parse_expression(expression['condition'], caller_context)
            then_expression = parse_expression(expression['trueExpression'], caller_context)
            else_expression = parse_expression(expression['falseExpression'], caller_context)
        else:
            children = expression['children']
            assert len(children) == 3
            if_expression = parse_expression(children[0], caller_context)
            then_expression = parse_expression(children[1], caller_context)
            else_expression = parse_expression(children[2], caller_context)
        conditional = ConditionalExpression(if_expression, then_expression, else_expression)
        return conditional

    elif name == 'Assignment':
        if is_compact_ast:
            left_expression = parse_expression(expression['leftHandSide'], caller_context)
            right_expression = parse_expression(expression['rightHandSide'], caller_context)
            operation_type = AssignmentOperationType.get_type(expression['operator'])
            operation_return_type = expression['typeDescriptions']['typeString']
        else:
            attributes = expression['attributes']
            children = expression['children']
            assert len(expression['children']) == 2
            left_expression = parse_expression(children[0], caller_context)
            right_expression = parse_expression(children[1], caller_context)
            operation_type = AssignmentOperationType.get_type(attributes['operator'])
            operation_return_type = attributes['type']

        assignment = AssignmentOperation(left_expression, right_expression,
                                         operation_type, operation_return_type)
        return assignment

    elif name == 'Literal':
        assert 'children' not in expression

        if is_compact_ast:
            value = expression.get('value')
            # Only a missing value means a hex literal. An empty string
            # literal ("") must stay "" instead of becoming '0x'.
            if value is None:
                value = '0x' + expression['hexValue']
        else:
            value = expression['attributes']['value']
            if value is None:
                # for literal declared as hex
                # see https://solidity.readthedocs.io/en/v0.4.25/types.html?highlight=hex#hexadecimal-literals
                assert 'hexvalue' in expression['attributes']
                value = '0x' + expression['attributes']['hexvalue']
        literal = Literal(value)
        return literal

    elif name == 'Identifier':
        assert 'children' not in expression

        type_str = None
        if is_compact_ast:
            value = expression['name']
            type_str = expression['typeDescriptions']['typeString']
        else:
            value = expression['attributes']['value']
            if 'type' in expression['attributes']:
                type_str = expression['attributes']['type']

        if type_str:
            # Raw string: the pattern is byte-identical to the former
            # non-raw literal, without the invalid-escape warnings.
            # NOTE(review): the leading [...] is a character class (one
            # character from the set), not an alternation of keywords;
            # kept as-is to preserve the current matching behavior.
            found = re.findall(
                r'[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)',
                type_str)
            assert len(found) <= 1
            if found:
                value = value + '(' + found[0] + ')'
                value = filter_name(value)

        var = find_variable(value, caller_context)
        identifier = Identifier(var)
        return identifier

    elif name == 'IndexAccess':
        if is_compact_ast:
            index_type = expression['typeDescriptions']['typeString']
            left = expression['baseExpression']
            right = expression['indexExpression']
        else:
            index_type = expression['attributes']['type']
            children = expression['children']
            assert len(children) == 2
            left = children[0]
            right = children[1]
        left_expression = parse_expression(left, caller_context)
        right_expression = parse_expression(right, caller_context)
        index = IndexAccess(left_expression, right_expression, index_type)
        return index

    elif name == 'MemberAccess':
        if is_compact_ast:
            member_name = expression['memberName']
            member_type = expression['typeDescriptions']['typeString']
            member_expression = parse_expression(expression['expression'], caller_context)
        else:
            member_name = expression['attributes']['member_name']
            member_type = expression['attributes']['type']
            children = expression['children']
            assert len(children) == 1
            member_expression = parse_expression(children[0], caller_context)
        if str(member_expression) == 'super':
            super_name = parse_super_name(expression, is_compact_ast)
            if isinstance(caller_context, Contract):
                inheritance = caller_context.inheritance
            else:
                assert isinstance(caller_context, Function)
                inheritance = caller_context.contract.inheritance
            # Resolve against the first parent in `inheritance` that defines it
            var = None
            for father in inheritance:
                try:
                    var = find_variable(super_name, father)
                    break
                except VariableNotFound:
                    continue
            if var is None:
                raise VariableNotFound('Variable not found: {}'.format(super_name))
            return SuperIdentifier(var)
        member_access = MemberAccess(member_name, member_type, member_expression)
        if str(member_access) in SOLIDITY_VARIABLES_COMPOSED:
            return Identifier(SolidityVariableComposed(str(member_access)))
        return member_access

    elif name == 'ElementaryTypeNameExpression':
        # nop exression
        # uint;
        if is_compact_ast:
            value = expression['typeName']
        else:
            assert 'children' not in expression
            value = expression['attributes']['value']
        elementary_type = parse_type(UnknownType(value), caller_context)
        return ElementaryTypeNameExpression(elementary_type)

    # NewExpression is not a root expression, it's always the child of another expression
    elif name == 'NewExpression':
        if is_compact_ast:
            type_name = expression['typeName']
        else:
            children = expression['children']
            assert len(children) == 1
            type_name = children[0]

        if type_name[key] == 'ArrayTypeName':
            depth = 0
            while type_name[key] == 'ArrayTypeName':
                # Note: dont conserve the size of the array if provided
                # We compute it directly
                if is_compact_ast:
                    type_name = type_name['baseType']
                else:
                    type_name = type_name['children'][0]
                depth += 1

            if type_name[key] == 'ElementaryTypeName':
                if is_compact_ast:
                    array_type = ElementaryType(type_name['name'])
                else:
                    array_type = ElementaryType(type_name['attributes']['name'])
            elif type_name[key] == 'UserDefinedTypeName':
                if is_compact_ast:
                    array_type = parse_type(UnknownType(type_name['name']), caller_context)
                else:
                    array_type = parse_type(UnknownType(type_name['attributes']['name']),
                                            caller_context)
            else:
                logger.error('Incorrect type array {}'.format(type_name))
                exit(-1)
            array = NewArray(depth, array_type)
            return array

        if type_name[key] == 'ElementaryTypeName':
            if is_compact_ast:
                elem_type = ElementaryType(type_name['name'])
            else:
                elem_type = ElementaryType(type_name['attributes']['name'])
            new_elem = NewElementaryType(elem_type)
            return new_elem

        assert type_name[key] == 'UserDefinedTypeName'

        if is_compact_ast:
            contract_name = type_name['name']
        else:
            contract_name = type_name['attributes']['name']
        new = NewContract(contract_name)
        return new

    elif name == 'ModifierInvocation':
        if is_compact_ast:
            called = parse_expression(expression['modifierName'], caller_context)
            arguments = []
            if expression['arguments']:
                arguments = [parse_expression(a, caller_context) for a in expression['arguments']]
        else:
            children = expression['children']
            called = parse_expression(children[0], caller_context)
            arguments = [parse_expression(a, caller_context) for a in children[1::]]

        call = CallExpression(called, arguments, 'Modifier')
        return call

    logger.error('Expression not parsed %s', name)
    exit(-1)
for ir in node.irs: if isinstance(ir, HighLevelCall): if ir.destination in taints: print('Call to tainted address found in {}'.format( function.name)) if __name__ == "__main__": if len(sys.argv) != 2: print('python taint_mapping.py taint.sol') exit(-1) # Init slither slither = Slither(sys.argv[1]) initial_taint = [SolidityVariableComposed('msg.sender')] initial_taint += [SolidityVariableComposed('msg.value')] KEY = 'TAINT' prev_taints = [] slither.context[KEY] = initial_taint while (set(prev_taints) != set(slither.context[KEY])): prev_taints = slither.context[KEY] for contract in slither.contracts: for function in contract.functions: print('Function {}'.format(function.name)) slither.context[KEY] = list( set(slither.context[KEY] + function.parameters)) visit_node(function.entry_point, []) print('All variables tainted : {}'.format(
class IncorrectStrictEquality(AbstractDetector):
    """Report strict equality (``==``) comparisons on attacker-influenced values.

    Values considered tainted are: results of ``Balance`` IR operations, results of
    high-level calls to a ``balanceOf(address)`` function or a public
    ``balanceOf`` mapping(address => uint256) state variable, variables assigned
    from ``now`` / ``block.number`` / ``block.timestamp``, and those three block
    values themselves. Any ``==`` whose operands depend (SSA) on one of them is
    reported.
    """

    ARGUMENT = "incorrect-equality"
    HELP = "Dangerous strict equalities"
    IMPACT = DetectorClassification.MEDIUM
    CONFIDENCE = DetectorClassification.HIGH

    WIKI = (
        "https://github.com/crytic/slither/wiki/Detector-Documentation#dangerous-strict-equalities"
    )

    WIKI_TITLE = "Dangerous strict equalities"
    WIKI_DESCRIPTION = "Use of strict equalities that can be easily manipulated by an attacker."
    WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Crowdsale{
    function fund_reached() public returns(bool){
        return this.balance == 100 ether;
    }
```
`Crowdsale` relies on `fund_reached` to know when to stop the sale of tokens.
`Crowdsale` reaches 100 Ether. Bob sends 0.1 Ether. As a result, `fund_reached` is always false and the `crowdsale` never ends."""

    WIKI_RECOMMENDATION = (
        """Don't use strict equality to determine if an account has enough Ether or tokens."""
    )

    # Block-level values treated as tainted in addition to balance reads.
    sources_taint = [
        SolidityVariable("now"),
        SolidityVariableComposed("block.number"),
        SolidityVariableComposed("block.timestamp"),
    ]

    @staticmethod
    def is_direct_comparison(ir):
        """Return True if ``ir`` is a ``Binary`` IR of type ``BinaryType.EQUAL``."""
        return isinstance(ir, Binary) and ir.type == BinaryType.EQUAL

    @staticmethod
    def is_any_tainted(variables, taints, function) -> bool:
        """Return True if any of ``variables`` SSA-depends on any of ``taints``.

        Dependencies are evaluated in the context of ``function.contract``.
        """
        return any(
            is_dependent_ssa(var, taint, function.contract)
            for var in variables
            for taint in taints
        )

    def taint_balance_equalities(self, functions):
        """Collect the SSA lvalues holding balance-like or block-derived values.

        Args:
            functions: functions/modifiers whose SSA IRs are scanned.

        Returns:
            list: lvalues of ``Balance`` IRs, of ``balanceOf`` high-level calls
            (function or ``mapping(address => uint256)`` state variable), and of
            assignments whose right-hand side is one of ``sources_taint``.
        """
        taints = []
        for func in functions:
            for node in func.nodes:
                for ir in node.irs_ssa:
                    if isinstance(ir, Balance):
                        taints.append(ir.lvalue)
                    if isinstance(ir, HighLevelCall):
                        if (
                            isinstance(ir.function, Function)
                            and ir.function.full_name == "balanceOf(address)"
                        ):
                            taints.append(ir.lvalue)
                        if (
                            isinstance(ir.function, StateVariable)
                            and isinstance(ir.function.type, MappingType)
                            and ir.function.name == "balanceOf"
                            and ir.function.type.type_from == ElementaryType("address")
                            and ir.function.type.type_to == ElementaryType("uint256")
                        ):
                            taints.append(ir.lvalue)
                    if isinstance(ir, Assignment):
                        if ir.rvalue in self.sources_taint:
                            taints.append(ir.lvalue)
        return taints

    # Retrieve all tainted (node, function) pairs
    def tainted_equality_nodes(self, funcs, taints):
        """Map each function to the nodes that contain a tainted ``==`` comparison.

        Args:
            funcs: functions/modifiers to inspect (may contain repeats).
            taints: tainted variables; not modified — ``sources_taint`` is added
                to a private copy.

        Returns:
            dict: function -> list of nodes, in discovery order. A node is listed
            at most once per function, even if it holds several tainted
            comparisons or its function appears more than once in ``funcs``.
        """
        all_taints = list(taints) + self.sources_taint
        results = {}
        for func in funcs:
            for node in func.nodes:
                # Filter to only tainted equality (==) comparisons
                has_tainted_equality = any(
                    self.is_direct_comparison(ir)
                    and self.is_any_tainted(ir.used, all_taints, func)
                    for ir in node.irs_ssa
                )
                if not has_tainted_equality:
                    continue
                bucket = results.setdefault(func, [])
                # Avoid duplicate findings (e.g. modifiers present twice in funcs)
                if node not in bucket:
                    bucket.append(node)
        return results

    def detect_strict_equality(self, contract):
        """Return {function: [nodes]} of tainted strict equalities reachable from ``contract``."""
        funcs = contract.all_functions_called + contract.modifiers

        # Taint all BALANCE accesses
        taints = self.taint_balance_equalities(funcs)

        # Accumulate tainted (node,function) pairs involved in strict equality (==) comparisons
        results = self.tainted_equality_nodes(funcs, taints)

        return results

    def _detect(self):
        """Emit one result per offending node, for every derived contract."""
        results = []

        for c in self.compilation_unit.contracts_derived:
            ret = self.detect_strict_equality(c)

            # sort ret to get deterministic results
            ret = sorted(list(ret.items()), key=lambda x: x[0].name)
            for f, nodes in ret:

                func_info = [f, " uses a dangerous strict equality:\n"]

                # sort the nodes to get deterministic results
                nodes.sort(key=lambda x: x.node_id)

                # Output each node with the function info header as a separate result.
                for node in nodes:
                    node_info = func_info + ["\t- ", node, "\n"]

                    res = self.generate_result(node_info)
                    results.append(res)

        return results
def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expression":
    # pylint: disable=too-many-nested-blocks,too-many-statements
    """
    Convert one solc AST expression node into a Slither expression.

    Both AST flavours are handled: the compact AST (``caller_context.is_compact_ast``),
    where fields sit directly on the node, and the legacy AST, where they are stored
    under ``"attributes"`` / ``"children"``. Sub-expressions are parsed recursively, and
    the node's ``"src"`` offset is attached to most created expressions.

    Args:
        expression: AST node; its kind is read from ``expression[caller_context.get_key()]``.
        caller_context: context used for the AST key, variable/type lookup and offsets.

    Returns:
        Expression: the parsed expression. For ``FunctionCallOptions``,
        ``IndexRangeAccess`` and index accesses without an index, the parsed
        inner/base expression is returned.

    Raises:
        ParsingError: unknown node kind, unsupported base type in a ``new`` array
            expression, or ``IdentifierPath`` in the legacy AST.
        VariableNotFound: the target of a ``super.<member>`` access cannot be found.
    """
    # Expression
    #    = Expression ('++' | '--')
    #    | NewExpression
    #    | IndexAccess
    #    | MemberAccess
    #    | FunctionCall
    #    | '(' Expression ')'
    #    | ('!' | '~' | 'delete' | '++' | '--' | '+' | '-') Expression
    #    | Expression '**' Expression
    #    | Expression ('*' | '/' | '%') Expression
    #    | Expression ('+' | '-') Expression
    #    | Expression ('<<' | '>>') Expression
    #    | Expression '&' Expression
    #    | Expression '^' Expression
    #    | Expression '|' Expression
    #    | Expression ('<' | '>' | '<=' | '>=') Expression
    #    | Expression ('==' | '!=') Expression
    #    | Expression '&&' Expression
    #    | Expression '||' Expression
    #    | Expression '?' Expression ':' Expression
    #    | Expression ('=' | '|=' | '^=' | '&=' | '<<=' | '>>=' | '+=' | '-=' | '*=' | '/=' | '%=') Expression
    #    | PrimaryExpression
    # The AST naming does not follow the spec
    name = expression[caller_context.get_key()]
    is_compact_ast = caller_context.is_compact_ast
    src = expression["src"]

    if name == "UnaryOperation":
        attributes = expression if is_compact_ast else expression["attributes"]
        assert "prefix" in attributes
        operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"])

        if is_compact_ast:
            operand = parse_expression(expression["subExpression"], caller_context)
        else:
            assert len(expression["children"]) == 1
            operand = parse_expression(expression["children"][0], caller_context)
        unary_op = UnaryOperation(operand, operation_type)
        unary_op.set_offset(src, caller_context.slither)
        return unary_op

    if name == "BinaryOperation":
        attributes = expression if is_compact_ast else expression["attributes"]
        operation_type = BinaryOperationType.get_type(attributes["operator"])

        if is_compact_ast:
            left_expression = parse_expression(expression["leftExpression"], caller_context)
            right_expression = parse_expression(expression["rightExpression"], caller_context)
        else:
            assert len(expression["children"]) == 2
            left_expression = parse_expression(expression["children"][0], caller_context)
            right_expression = parse_expression(expression["children"][1], caller_context)
        binary_op = BinaryOperation(left_expression, right_expression, operation_type)
        binary_op.set_offset(src, caller_context.slither)
        return binary_op

    # Exact match: `name in "FunctionCall"` would be a substring test and would also
    # accept e.g. "Function" or "Call".
    if name == "FunctionCall":
        return parse_call(expression, caller_context)

    if name == "FunctionCallOptions":
        # call/gas info are handled in parse_call
        if is_compact_ast:
            called = parse_expression(expression["expression"], caller_context)
        else:
            called = parse_expression(expression["children"][0], caller_context)
        assert isinstance(called, (MemberAccess, NewContract, Identifier, TupleExpression))
        return called

    if name == "TupleExpression":
        #    For expression like
        #    (a,,c) = (1,2,3)
        #    the AST provides only two children in the left side
        #    We check the type provided (tuple(uint256,,uint256))
        #    To determine that there is an empty variable
        #    Otherwise we would not be able to determine that
        #    a = 1, c = 3, and 2 is lost
        #
        #    Note: this is only possible with Solidity >= 0.4.12
        if is_compact_ast:
            expressions = [
                parse_expression(e, caller_context) if e else None for e in expression["components"]
            ]
        else:
            if "children" not in expression:
                components = expression["attributes"]["components"]
                expressions = [
                    parse_expression(c, caller_context) if c else None for c in components
                ]
            else:
                expressions = [parse_expression(e, caller_context) for e in expression["children"]]
            # Add none for empty tuple items
            if "attributes" in expression:
                if "type" in expression["attributes"]:
                    t = expression["attributes"]["type"]
                    if ",," in t or "(," in t or ",)" in t:
                        t = t[len("tuple(") : -1]
                        elems = t.split(",")
                        for idx, elem in enumerate(elems):
                            if elem == "":
                                expressions.insert(idx, None)
        t = TupleExpression(expressions)
        t.set_offset(src, caller_context.slither)
        return t

    if name == "Conditional":
        if is_compact_ast:
            if_expression = parse_expression(expression["condition"], caller_context)
            then_expression = parse_expression(expression["trueExpression"], caller_context)
            else_expression = parse_expression(expression["falseExpression"], caller_context)
        else:
            children = expression["children"]
            assert len(children) == 3
            if_expression = parse_expression(children[0], caller_context)
            then_expression = parse_expression(children[1], caller_context)
            else_expression = parse_expression(children[2], caller_context)
        conditional = ConditionalExpression(if_expression, then_expression, else_expression)
        conditional.set_offset(src, caller_context.slither)
        return conditional

    if name == "Assignment":
        if is_compact_ast:
            left_expression = parse_expression(expression["leftHandSide"], caller_context)
            right_expression = parse_expression(expression["rightHandSide"], caller_context)

            operation_type = AssignmentOperationType.get_type(expression["operator"])

            operation_return_type = expression["typeDescriptions"]["typeString"]
        else:
            attributes = expression["attributes"]
            children = expression["children"]
            assert len(expression["children"]) == 2
            left_expression = parse_expression(children[0], caller_context)
            right_expression = parse_expression(children[1], caller_context)

            operation_type = AssignmentOperationType.get_type(attributes["operator"])
            operation_return_type = attributes["type"]

        assignment = AssignmentOperation(
            left_expression, right_expression, operation_type, operation_return_type
        )
        assignment.set_offset(src, caller_context.slither)
        return assignment

    if name == "Literal":

        subdenomination = None

        assert "children" not in expression

        if is_compact_ast:
            value = expression["value"]
            if value:
                if "subdenomination" in expression and expression["subdenomination"]:
                    subdenomination = expression["subdenomination"]
            elif not value and value != "":
                value = "0x" + expression["hexValue"]
            type_candidate = expression["typeDescriptions"]["typeString"]

            # Length declaration for array was None until solc 0.5.5
            if type_candidate is None:
                if expression["kind"] == "number":
                    type_candidate = "int_const"
        else:
            value = expression["attributes"]["value"]
            if value:
                if (
                    "subdenomination" in expression["attributes"]
                    and expression["attributes"]["subdenomination"]
                ):
                    subdenomination = expression["attributes"]["subdenomination"]
            elif value is None:
                # for literal declared as hex
                # see https://solidity.readthedocs.io/en/v0.4.25/types.html?highlight=hex#hexadecimal-literals
                assert "hexvalue" in expression["attributes"]
                value = "0x" + expression["attributes"]["hexvalue"]
            type_candidate = expression["attributes"]["type"]

        if type_candidate is None:
            if value.isdecimal():
                type_candidate = ElementaryType("uint256")
            else:
                type_candidate = ElementaryType("string")
        # No trailing space in the prefix: solc emits "int_const <value>", and the
        # placeholder "int_const" synthesized above must map to uint256 as well.
        elif type_candidate.startswith("int_const"):
            type_candidate = ElementaryType("uint256")
        elif type_candidate.startswith("bool"):
            type_candidate = ElementaryType("bool")
        elif type_candidate.startswith("address"):
            type_candidate = ElementaryType("address")
        else:
            type_candidate = ElementaryType("string")
        literal = Literal(value, type_candidate, subdenomination)
        literal.set_offset(src, caller_context.slither)
        return literal

    if name == "Identifier":
        assert "children" not in expression

        t = None

        if caller_context.is_compact_ast:
            value = expression["name"]
            t = expression["typeDescriptions"]["typeString"]
        else:
            value = expression["attributes"]["value"]
            if "type" in expression["attributes"]:
                t = expression["attributes"]["type"]

        if t:
            # Raw string: same regex as before, without invalid-escape warnings.
            # NOTE(review): "[struct|enum|function|modifier]" is a character class
            # (any single one of those letters), not an alternation of keywords;
            # kept as-is to preserve current matching behaviour.
            found = re.findall(r"[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)", t)
            assert len(found) <= 1
            if found:
                value = value + "(" + found[0] + ")"
                value = filter_name(value)

        referenced_declaration = expression.get("referencedDeclaration")

        var = find_variable(value, caller_context, referenced_declaration)

        identifier = Identifier(var)
        identifier.set_offset(src, caller_context.slither)
        return identifier

    if name == "IndexAccess":
        if is_compact_ast:
            index_type = expression["typeDescriptions"]["typeString"]
            left = expression["baseExpression"]
            right = expression.get("indexExpression", None)
        else:
            index_type = expression["attributes"]["type"]
            children = expression["children"]
            left = children[0]
            right = children[1] if len(children) > 1 else None
        # IndexAccess is used to describe ElementaryTypeNameExpression
        # if abi.decode is used
        # For example, abi.decode(data, ...(uint[]) )
        if right is None:
            ret = parse_expression(left, caller_context)
            # Nested array are not yet available in abi.decode
            if isinstance(ret, ElementaryTypeNameExpression):
                old_type = ret.type
                ret.type = ArrayType(old_type, None)
            return ret

        left_expression = parse_expression(left, caller_context)
        right_expression = parse_expression(right, caller_context)
        index = IndexAccess(left_expression, right_expression, index_type)
        index.set_offset(src, caller_context.slither)
        return index

    if name == "MemberAccess":
        if caller_context.is_compact_ast:
            member_name = expression["memberName"]
            member_type = expression["typeDescriptions"]["typeString"]
            member_expression = parse_expression(expression["expression"], caller_context)
        else:
            member_name = expression["attributes"]["member_name"]
            member_type = expression["attributes"]["type"]
            children = expression["children"]
            assert len(children) == 1
            member_expression = parse_expression(children[0], caller_context)
        if str(member_expression) == "super":
            super_name = parse_super_name(expression, is_compact_ast)
            var = find_variable(super_name, caller_context, is_super=True)
            if var is None:
                raise VariableNotFound("Variable not found: {}".format(super_name))
            sup = SuperIdentifier(var)
            sup.set_offset(src, caller_context.slither)
            return sup
        member_access = MemberAccess(member_name, member_type, member_expression)
        member_access.set_offset(src, caller_context.slither)
        # Built-ins such as msg.sender are represented as identifiers, not member accesses
        if str(member_access) in SOLIDITY_VARIABLES_COMPOSED:
            id_idx = Identifier(SolidityVariableComposed(str(member_access)))
            id_idx.set_offset(src, caller_context.slither)
            return id_idx
        return member_access

    if name == "ElementaryTypeNameExpression":
        return _parse_elementary_type_name_expression(expression, is_compact_ast, caller_context)

    # NewExpression is not a root expression, it's always the child of another expression
    if name == "NewExpression":

        if is_compact_ast:
            type_name = expression["typeName"]
        else:
            children = expression["children"]
            assert len(children) == 1
            type_name = children[0]

        if type_name[caller_context.get_key()] == "ArrayTypeName":
            depth = 0
            while type_name[caller_context.get_key()] == "ArrayTypeName":
                # Note: dont conserve the size of the array if provided
                # We compute it directly
                if is_compact_ast:
                    type_name = type_name["baseType"]
                else:
                    type_name = type_name["children"][0]
                depth += 1
            if type_name[caller_context.get_key()] == "ElementaryTypeName":
                if is_compact_ast:
                    array_type = ElementaryType(type_name["name"])
                else:
                    array_type = ElementaryType(type_name["attributes"]["name"])
            elif type_name[caller_context.get_key()] == "UserDefinedTypeName":
                if is_compact_ast:
                    array_type = parse_type(UnknownType(type_name["name"]), caller_context)
                else:
                    array_type = parse_type(
                        UnknownType(type_name["attributes"]["name"]), caller_context
                    )
            elif type_name[caller_context.get_key()] == "FunctionTypeName":
                array_type = parse_type(type_name, caller_context)
            else:
                raise ParsingError("Incorrect type array {}".format(type_name))
            array = NewArray(depth, array_type)
            array.set_offset(src, caller_context.slither)
            return array

        if type_name[caller_context.get_key()] == "ElementaryTypeName":
            if is_compact_ast:
                elem_type = ElementaryType(type_name["name"])
            else:
                elem_type = ElementaryType(type_name["attributes"]["name"])
            new_elem = NewElementaryType(elem_type)
            new_elem.set_offset(src, caller_context.slither)
            return new_elem

        assert type_name[caller_context.get_key()] == "UserDefinedTypeName"

        if is_compact_ast:

            # Changed introduced in Solidity 0.8
            # see https://github.com/crytic/slither/issues/794

            # TODO explore more the changes introduced in 0.8 and the usage of pathNode/IdentifierPath
            if "name" not in type_name:
                assert "pathNode" in type_name and "name" in type_name["pathNode"]
                contract_name = type_name["pathNode"]["name"]
            else:
                contract_name = type_name["name"]
        else:
            contract_name = type_name["attributes"]["name"]
        new = NewContract(contract_name)
        new.set_offset(src, caller_context.slither)
        return new

    if name == "ModifierInvocation":

        if is_compact_ast:
            called = parse_expression(expression["modifierName"], caller_context)
            arguments = []
            if expression.get("arguments", None):
                arguments = [parse_expression(a, caller_context) for a in expression["arguments"]]
        else:
            children = expression["children"]
            called = parse_expression(children[0], caller_context)
            arguments = [parse_expression(a, caller_context) for a in children[1::]]

        call = CallExpression(called, arguments, "Modifier")
        call.set_offset(src, caller_context.slither)
        return call

    if name == "IndexRangeAccess":
        # For now, we convert array slices to a direct array access
        # As a result the generated IR will lose the slices information
        # As far as I understand, array slice are only used in abi.decode
        # https://solidity.readthedocs.io/en/v0.6.12/types.html
        # TODO: Investigate array slices usage and implication for the IR
        base = parse_expression(expression["baseExpression"], caller_context)
        return base

    # Introduced with solc 0.8
    if name == "IdentifierPath":

        if caller_context.is_compact_ast:
            value = expression["name"]

            referenced_declaration = expression.get("referencedDeclaration")

            var = find_variable(value, caller_context, referenced_declaration)

            identifier = Identifier(var)
            identifier.set_offset(src, caller_context.slither)
            return identifier

        raise ParsingError("IdentifierPath not currently supported for the legacy ast")

    raise ParsingError("Expression not parsed %s" % name)
class IncorrectStrictEquality(AbstractDetector):
    """Report strict equality (``==``) comparisons on attacker-influenced values.

    Tainted values are: results of ``Balance`` IRs, results of high-level calls to a
    ``balanceOf(address)`` function or a public ``balanceOf``
    mapping(address => uint256), variables assigned from ``now`` / ``block.number`` /
    ``block.timestamp``, and those block values themselves.
    """

    ARGUMENT = 'incorrect-equality'
    HELP = 'Dangerous strict equalities'
    IMPACT = DetectorClassification.MEDIUM
    CONFIDENCE = DetectorClassification.HIGH

    WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#dangerous-strict-equalities'

    WIKI_TITLE = 'Dangerous strict equalities'
    WIKI_DESCRIPTION = 'Use of strict equalities that can be easily manipulated by an attacker.'
    WIKI_EXPLOIT_SCENARIO = '''
```solidity
contract Crowdsale{
    function fund_reached() public returns(bool){
        return this.balance == 100 ether;
    }
```
`Crowdsale` relies on `fund_reached` to know when to stop the sale of tokens.
`Crowdsale` reaches 100 ether. Bob sends 0.1 ether. As a result, `fund_reached` is always false and the crowdsale never ends.'''

    WIKI_RECOMMENDATION = '''Don't use strict equality to determine if an account has enough ethers or tokens.'''

    # Block-level values treated as tainted in addition to balance reads.
    sources_taint = [
        SolidityVariable('now'),
        SolidityVariableComposed('block.number'),
        SolidityVariableComposed('block.timestamp'),
    ]

    @staticmethod
    def is_direct_comparison(ir):
        """Return True if ``ir`` is a ``Binary`` IR of type ``BinaryType.EQUAL``."""
        return isinstance(ir, Binary) and ir.type == BinaryType.EQUAL

    @staticmethod
    def is_any_tainted(variables, taints, function):
        """Return True if any of ``variables`` SSA-depends on any of ``taints``.

        Uses a generator so the search stops at the first dependency found.
        """
        return any(
            is_dependent_ssa(var, taint, function.contract)
            for var in variables
            for taint in taints
        )

    def taint_balance_equalities(self, functions):
        """Collect the SSA lvalues holding balance-like or block-derived values."""
        taints = []
        for func in functions:
            for node in func.nodes:
                for ir in node.irs_ssa:
                    if isinstance(ir, Balance):
                        taints.append(ir.lvalue)
                    if isinstance(ir, HighLevelCall):
                        if isinstance(ir.function, Function) and \
                                ir.function.full_name == 'balanceOf(address)':
                            taints.append(ir.lvalue)
                        if isinstance(ir.function, StateVariable) and \
                                isinstance(ir.function.type, MappingType) and \
                                ir.function.name == 'balanceOf' and \
                                ir.function.type.type_from == ElementaryType('address') and \
                                ir.function.type.type_to == ElementaryType('uint256'):
                            taints.append(ir.lvalue)
                    if isinstance(ir, Assignment):
                        if ir.rvalue in self.sources_taint:
                            taints.append(ir.lvalue)
        return taints

    # Retrieve all tainted (node, function) pairs
    def tainted_equality_nodes(self, funcs, taints):
        """Map each function to the nodes that contain a tainted ``==`` comparison.

        ``taints`` is left untouched (``sources_taint`` is added to a copy). Each node
        is listed at most once per function, even when it has several tainted
        comparisons or its function appears more than once in ``funcs``.
        """
        all_taints = list(taints) + self.sources_taint
        results = {}
        for func in funcs:
            for node in func.nodes:
                # Filter to only tainted equality (==) comparisons
                if not any(
                        self.is_direct_comparison(ir) and
                        self.is_any_tainted(ir.used, all_taints, func)
                        for ir in node.irs_ssa):
                    continue
                bucket = results.setdefault(func, [])
                if node not in bucket:
                    bucket.append(node)
        return results

    def detect_strict_equality(self, contract):
        """Return {function: [nodes]} of tainted strict equalities reachable from ``contract``."""
        funcs = contract.all_functions_called + contract.modifiers

        # Taint all BALANCE accesses
        taints = self.taint_balance_equalities(funcs)

        # Accumulate tainted (node,function) pairs involved in strict equality (==) comparisons
        results = self.tainted_equality_nodes(funcs, taints)

        return results

    def detect(self):
        """Build one JSON result per offending function and log a per-contract summary."""
        results = []

        for c in self.slither.contracts_derived:
            ret = self.detect_strict_equality(c)

            # Text logged once per contract: concatenation of the per-function reports
            contract_info = ''
            # sort ret to get deterministic results
            ret = sorted(list(ret.items()), key=lambda x: x[0].name)
            for f, nodes in ret:
                # Built per function so each JSON result describes only its own function
                info = "{}.{} ({}) uses a dangerous strict equality:\n".format(
                    f.contract.name, f.name, f.source_mapping_str)

                # sort the nodes to get deterministic results
                nodes.sort(key=lambda x: x.node_id)
                for node in nodes:
                    info += "\t- {}\n".format(str(node.expression))

                result = self.generate_json_result(info)
                self.add_function_to_json(f, result)
                self.add_nodes_to_json(nodes, result)
                results.append(result)

                contract_info += info

            if contract_info:
                self.log(contract_info)

        return results
class IncorrectStrictEquality(AbstractDetector):
    """Report strict equality (``==``) comparisons on attacker-influenced values.

    Tainted values are: results of ``Balance`` IRs, results of high-level calls to a
    ``balanceOf(address)`` function or a public ``balanceOf``
    mapping(address => uint256), variables assigned from ``now`` / ``block.number`` /
    ``block.timestamp``, and those block values themselves.
    """

    ARGUMENT = 'incorrect-equality'
    HELP = 'Dangerous strict equalities'
    IMPACT = DetectorClassification.MEDIUM
    CONFIDENCE = DetectorClassification.HIGH

    WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#dangerous-strict-equalities'

    # Block-level values treated as tainted in addition to balance reads.
    sources_taint = [SolidityVariable('now'),
                     SolidityVariableComposed('block.number'),
                     SolidityVariableComposed('block.timestamp')]

    @staticmethod
    def is_direct_comparison(ir):
        """Return True if ``ir`` is a ``Binary`` IR of type ``BinaryType.EQUAL``."""
        return isinstance(ir, Binary) and ir.type == BinaryType.EQUAL

    @staticmethod
    def is_any_tainted(variables, taints, function):
        """Return True if any of ``variables`` SSA-depends on any of ``taints``.

        Uses a generator so the search stops at the first dependency found.
        """
        return any(is_dependent_ssa(var, taint, function.contract)
                   for var in variables
                   for taint in taints)

    def taint_balance_equalities(self, functions):
        """Collect the SSA lvalues holding balance-like or block-derived values."""
        taints = []
        for func in functions:
            for node in func.nodes:
                for ir in node.irs_ssa:
                    if isinstance(ir, Balance):
                        taints.append(ir.lvalue)
                    if isinstance(ir, HighLevelCall):
                        if isinstance(ir.function, Function) and \
                                ir.function.full_name == 'balanceOf(address)':
                            taints.append(ir.lvalue)
                        if isinstance(ir.function, StateVariable) and \
                                isinstance(ir.function.type, MappingType) and \
                                ir.function.name == 'balanceOf' and \
                                ir.function.type.type_from == ElementaryType('address') and \
                                ir.function.type.type_to == ElementaryType('uint256'):
                            taints.append(ir.lvalue)
                    if isinstance(ir, Assignment):
                        if ir.rvalue in self.sources_taint:
                            taints.append(ir.lvalue)
        return taints

    # Retrieve all tainted (node, function) pairs
    def tainted_equality_nodes(self, funcs, taints):
        """Map each function to the nodes that contain a tainted ``==`` comparison.

        ``taints`` is left untouched (``sources_taint`` is added to a copy). Each node
        is listed at most once per function, even when it has several tainted
        comparisons or its function appears more than once in ``funcs``.
        """
        all_taints = list(taints) + self.sources_taint
        results = {}
        for func in funcs:
            for node in func.nodes:
                # Filter to only tainted equality (==) comparisons
                if not any(self.is_direct_comparison(ir) and
                           self.is_any_tainted(ir.used, all_taints, func)
                           for ir in node.irs_ssa):
                    continue
                bucket = results.setdefault(func, [])
                if node not in bucket:
                    bucket.append(node)
        return results

    def detect_strict_equality(self, contract):
        """Return {function: [nodes]} of tainted strict equalities reachable from ``contract``."""
        funcs = contract.all_functions_called + contract.modifiers

        # Taint all BALANCE accesses
        taints = self.taint_balance_equalities(funcs)

        # Accumulate tainted (node,function) pairs involved in strict equality (==) comparisons
        results = self.tainted_equality_nodes(funcs, taints)

        return results

    def detect(self):
        """Build one JSON result per offending function and log a per-contract summary."""
        results = []

        for c in self.slither.contracts_derived:
            ret = self.detect_strict_equality(c)

            # Text logged once per contract: concatenation of the per-function reports
            contract_info = ''
            # sort ret to get deterministic results
            ret = sorted(list(ret.items()), key=lambda x: x[0].name)
            for f, nodes in ret:
                # Built per function so each JSON result describes only its own function
                info = "{}.{} ({}) uses a dangerous strict equality:\n".format(f.contract.name,
                                                                                 f.name,
                                                                                 f.source_mapping_str)

                # sort the nodes to get deterministic results
                nodes.sort(key=lambda x: x.node_id)
                for node in nodes:
                    info += "\t- {}\n".format(str(node.expression))

                result = self.generate_json_result(info)
                self.add_function_to_json(f, result)
                self.add_nodes_to_json(nodes, result)
                results.append(result)

                contract_info += info

            if contract_info:
                self.log(contract_info)

        return results