def _assert_reason(self, test_expr, msg):
    """Build LLL for ``assert <test_expr>, <msg>`` carrying a revert reason.

    If ``msg`` is the name ``UNREACHABLE`` the work is delegated to
    ``self._assert_unreachable``. Otherwise ``msg`` must be a non-empty string
    literal. The generated code:
    - stores the ``Error(string)`` selector and the constant 32 in two fresh
      32-byte placeholders;
    - includes the LLL that ``Expr(msg)`` produces for the literal;
    - wraps everything in an ``assert_reason`` node.

    Args:
        test_expr: LLL for the asserted condition.
        msg: AST node supplied as the assert's reason.

    Returns:
        An untyped LLLnode positioned at ``self.stmt``.

    Raises:
        StructureException: if ``msg`` is neither ``UNREACHABLE`` nor a string
            literal, or if the literal is empty after stripping whitespace.
    """
    if isinstance(msg, sri_ast.Name) and msg.id == 'UNREACHABLE':
        return self._assert_unreachable(test_expr, msg)

    if not isinstance(msg, sri_ast.Str):
        raise StructureException(
            'Reason parameter of assert needs to be a literal string '
            '(or UNREACHABLE constant).',
            msg)
    if len(msg.s.strip()) == 0:
        raise StructureException('Empty reason string not allowed.', self.stmt)
    reason_str = msg.s.strip()
    # NOTE(review): the revert payload is read as one contiguous region
    # starting at sig_placeholder + 28. This appears to rely on these two
    # placeholders and the one Expr(msg) allocates below being laid out
    # back-to-back. Do not reorder these allocations without confirming.
    sig_placeholder = self.context.new_placeholder(BaseType(32))
    arg_placeholder = self.context.new_placeholder(BaseType(32))
    # NOTE(review): the payload size is derived from the *stripped* string,
    # while placeholder_bytes below is built from the unstripped literal.
    # The two disagree when the reason has leading/trailing whitespace.
    reason_str_type = ByteArrayType(len(reason_str))
    placeholder_bytes = Expr(msg, self.context).lll_node
    # First 4 bytes of keccak256("Error(string)"): the standard revert selector.
    method_id = fourbytes_to_int(keccak256(b"Error(string)")[:4])
    assert_reason = [
        'seq',
        ['mstore', sig_placeholder, method_id],
        # Presumably the ABI offset of the string argument.
        ['mstore', arg_placeholder, 32],
        placeholder_bytes,
        [
            'assert_reason',
            test_expr,
            # mstore right-aligns the 4-byte selector in its 32-byte word,
            # so the payload begins 28 bytes into that word.
            int(sig_placeholder + 28),
            int(4 + get_size_of_type(reason_str_type) * 32),
        ],
    ]
    return LLLnode.from_list(assert_reason, typ=None, pos=getpos(self.stmt))
def get_target(self, target):
    """Resolve an assignment target to LLL, enforcing loop-safety rules.

    Inside a for loop:
    - writing through a subscript into a list registered in
      ``self.context.in_for_loop`` is rejected;
    - rebinding a variable listed in ``self.context.forvars`` is rejected.

    Tuple targets are expanded element-wise; all other targets are resolved
    with ``Expr.parse_variable_location``. ``constancy_checks`` runs on every
    resolved node.

    Args:
        target: AST node on the left-hand side of an assignment.

    Returns:
        The LLL node (for tuples, the tuple's LLL node) for the target.

    Raises:
        StructureException: when altering an iterated list or a loop variable.
    """

    def dotted_name(node):
        # "a.b.c" for a chain of Attribute nodes rooted at a Name, the bare id
        # for a Name, and None for anything else (e.g. a nested subscript).
        parts = []
        while isinstance(node, sri_ast.Attribute):
            parts.append(node.attr)
            node = node.value
        if not isinstance(node, sri_ast.Name):
            return None
        parts.append(node.id)
        return ".".join(reversed(parts))

    # Check if we are doing assignment of an iteration loop.
    if isinstance(target, sri_ast.Subscript) and self.context.in_for_loop:
        # Previously only `x.y[...]` was handled: deeper chains such as
        # `self.s.arr[...]` crashed with AttributeError on `.value.value.id`.
        list_name = dotted_name(target.value)
        if list_name is not None and list_name in self.context.in_for_loop:
            raise StructureException(
                f"Altering list '{list_name}' which is being iterated!",
                self.stmt,
            )

    if isinstance(target, sri_ast.Name) and target.id in self.context.forvars:
        raise StructureException(
            f"Altering iterator '{target.id}' which is in use!",
            self.stmt,
        )

    if isinstance(target, sri_ast.Tuple):
        target = Expr(target, self.context).lll_node
        for node in target.args:
            constancy_checks(node, self.context, self.stmt)
        return target

    target = Expr.parse_variable_location(target, self.context)
    constancy_checks(target, self.context, self.stmt)
    return target
def validate_call_args(node: sri_ast.Call,
                       arg_count: Union[int, tuple],
                       kwargs: Optional[list] = None) -> None:
    """
    Check the shape of a Call node: number of positional arguments and
    names of keyword arguments. Argument types are NOT checked here.

    Arguments
    ---------
    node : Call
        srilang ast Call node to be validated.
    arg_count : int | tuple
        Exact number of positional arguments, or a ``(minimum, maximum)``
        tuple of the allowed range.
    kwargs : list, optional
        Allowed keyword argument names. When ``arg_count`` is a tuple,
        positional arguments beyond the minimum are treated as filling the
        leading entries of this list, so those names may not be repeated
        as keywords.

    Returns
    -------
    None. An exception is raised when the arguments are invalid.
    """
    allowed_kwargs = [] if kwargs is None else kwargs

    if not isinstance(node, sri_ast.Call):
        raise StructureException("Expected Call", node)
    if not isinstance(arg_count, (int, tuple)):
        raise CompilerPanic(
            f"Invalid type for arg_count: {type(arg_count).__name__}")

    given = len(node.args)
    if isinstance(arg_count, int):
        if given != arg_count:
            raise ArgumentException(
                f"Invalid argument count: expected {arg_count}, got {given}",
                node)
    else:
        lower, upper = arg_count[0], arg_count[1]
        if not lower <= given <= upper:
            raise ArgumentException(
                f"Invalid argument count: expected between "
                f"{lower} and {upper}, got {given}",
                node,
            )

    if node.keywords and not allowed_kwargs:
        raise ArgumentException(
            "Keyword arguments are not accepted here", node.keywords[0])

    for keyword in node.keywords:
        name = keyword.arg
        if name is None:
            raise StructureException("Use of **kwargs is not supported", keyword.value)
        if name not in allowed_kwargs:
            raise ArgumentException(f"Invalid keyword argument '{name}'", keyword)
        # Surplus positionals already occupy the first kwargs slots.
        if isinstance(arg_count, tuple) and allowed_kwargs.index(name) < given - arg_count[0]:
            raise ArgumentException(
                f"'{name}' was given as a positional argument", keyword)
def check_valid_contract_interface(global_ctx, contract_sigs):
    """Validate a contract's signatures for ID clashes and interface compliance.

    First rejects any two FunctionSignatures (public or private) whose method
    IDs collide. Then, if ``global_ctx._interface`` is set, checks that every
    entry in it is implemented:
    - a function must be a non-private FunctionSignature with the same key
      whose output type passes ``_compare_outputs``;
    - an event must be an EventSignature with the same ``sig``.

    Args:
        global_ctx: global context; ``_interface`` maps signature keys to
            the interface's required Function/Event signatures.
        contract_sigs: mapping of signature keys to this contract's signatures.

    Raises:
        StructureException: on a method-ID conflict, or listing every missing
            interface function/event.
    """
    # the check for private function collisions is made to prevent future
    # breaking changes if we switch to internal calls (@iamdefinitelyahuman)
    func_sigs = [sig for sig in contract_sigs.values() if isinstance(sig, FunctionSignature)]
    func_conflicts = find_signature_conflicts(func_sigs)

    if len(func_conflicts) > 0:
        sig_1, sig_2 = func_conflicts[0]
        raise StructureException(
            f'Methods {sig_1.sig} and {sig_2.sig} have conflicting IDs '
            f'(id {sig_1.method_id})',
            sig_1.func_ast_code,
        )

    if not global_ctx._interface:
        return

    funcs_left = global_ctx._interface.copy()

    for sig, func_sig in contract_sigs.items():
        if isinstance(func_sig, FunctionSignature):
            if func_sig.private:
                # private functions are not defined within interfaces
                continue
            if sig not in funcs_left:
                # this function is not present within the interface
                continue
            clean_sig_output_type = func_sig.output_type
            if _compare_outputs(funcs_left[sig].output_type, clean_sig_output_type):
                del funcs_left[sig]
        if isinstance(func_sig, EventSignature) and func_sig.sig in funcs_left:
            del funcs_left[func_sig.sig]

    if funcs_left:
        missing_functions = [
            str(func_sig) for func_sig in funcs_left.values()
            if isinstance(func_sig, FunctionSignature)
        ]
        missing_events = [
            sig_name for sig_name, func_sig in funcs_left.items()
            if isinstance(func_sig, EventSignature)
        ]
        sections = []
        if missing_functions:
            err_join = "\n\t".join(missing_functions)
            sections.append(f'Missing interface functions:\n\t{err_join}')
        if missing_events:
            err_join = "\n\t".join(missing_events)
            sections.append(f'Missing interface events:\n\t{err_join}')
        # Sections are newline-separated; previously the events header was
        # glued onto the last missing function when both were present.
        error_message = 'Contract does not comply to supplied Interface(s).\n' + "\n".join(sections)
        raise StructureException(error_message)
def _check_return_body(node, node_list): return_count = len([n for n in node_list if is_return_from_function(n)]) if return_count > 1: raise StructureException( f'Too too many exit statements (return, raise or selfdestruct).', node) # Check for invalid code after returns. last_node_pos = len(node_list) - 1 for idx, n in enumerate(node_list): if is_return_from_function(n) and idx < last_node_pos: # is not last statement in body. raise StructureException( 'Exit statement with succeeding code (that will not execute).', node_list[idx + 1])
def call(self):
    """Compile a call expression.

    Dispatch order:
    - ``self.<name>(...)`` -> ``self_call.make_call``;
    - calls whose callee is neither a plain name nor a ``self`` attribute
      -> ``external_call.make_external_call``;
    - plain names are resolved, in order, as a builtin from
      ``DISPATCH_TABLE``, a struct constructor, or a contract cast
      ``Bar(<address>)``; anything else is an error, with a hint when a
      ``self`` method of that name exists.

    Raises:
        StructureException: wrong struct-constructor arity or unknown name.
        TypeMismatch: struct constructor without a dict, or an invalid
            contract cast.
    """
    from srilang.functions import (
        DISPATCH_TABLE,
    )

    func = self.expr.func
    if not isinstance(func, sri_ast.Name):
        calls_self = (
            isinstance(func, sri_ast.Attribute)
            and isinstance(func.value, sri_ast.Name)
            and func.value.id == "self"
        )
        if calls_self:
            return self_call.make_call(self.expr, self.context)
        return external_call.make_external_call(self.expr, self.context)

    function_name = func.id
    if function_name in DISPATCH_TABLE:
        return DISPATCH_TABLE[function_name].build_LLL(self.expr, self.context)

    # Struct constructors do not need `self` prefix.
    if function_name in self.context.structs:
        call_args = self.expr.args
        if len(call_args) != 1:
            raise StructureException(
                "Struct constructor is called with one argument only",
                self.expr,
            )
        (struct_arg,) = call_args
        if not isinstance(struct_arg, sri_ast.Dict):
            raise TypeMismatch(
                "Struct can only be constructed with a dict",
                self.expr,
            )
        return Expr.struct_literals(struct_arg, function_name, self.context)

    # Contract assignment. Bar(<address>).
    if function_name in self.context.sigs:
        is_valid, address_lll = self._is_valid_contract_assign()
        if is_valid is not True:
            raise TypeMismatch(
                "ContractType definition expects one address argument.",
                self.expr,
            )
        # Cast to the correct contract type.
        address_lll.typ = ContractType(function_name)
        return address_lll

    err_msg = f"Not a top-level function: {function_name}"
    self_method_names = {key.split('(')[0] for key in self.context.sigs['self']}
    if function_name in self_method_names:
        err_msg += f". Did you mean self.{function_name}?"
    raise StructureException(err_msg, self.expr)
def boolean_operations(self):
    """Compile an ``and``/``or`` expression over two or more boolean operands.

    Every operand is type-checked first. Calls are rejected when
    ``self.context.in_assignment`` is set. Operands beyond the second are
    folded in as ``[op, value, previous]``.

    Returns:
        LLLnode of type ``bool``.

    Raises:
        StructureException: a call operand during assignment, or fewer than
            two operands.
        TypeMismatch: a non-boolean operand.
        Exception: an operator other than And/Or.
    """
    # Iterate through values
    for value in self.expr.values:
        # Check for calls at assignment
        if self.context.in_assignment and isinstance(value, sri_ast.Call):
            raise StructureException(
                "Boolean operations with calls may not be performed on assignment",
                self.expr,
            )

        # Check for boolean operations with non-boolean inputs
        _expr = Expr.parse_value_expr(value, self.context)
        if not is_base_type(_expr.typ, 'bool'):
            raise TypeMismatch(
                "Boolean operations can only be between booleans!",
                self.expr,
            )

        # TODO: Handle special case of literals and simplify at compile time

    # Check for valid ops
    if isinstance(self.expr.op, sri_ast.And):
        op = 'and'
    elif isinstance(self.expr.op, sri_ast.Or):
        op = 'or'
    else:
        # f-string: `str + node` raised TypeError instead of this message.
        raise Exception(f"Unsupported bool op: {self.expr.op}")

    # Handle different numbers of inputs
    count = len(self.expr.values)
    if count < 2:
        raise StructureException("Expected at least two arguments for a bool op", self.expr)

    left = Expr.parse_value_expr(self.expr.values[0], self.context)
    right = Expr.parse_value_expr(self.expr.values[1], self.context)
    if count == 2:
        return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))

    p = ['seq', [op, left, right]]
    for value_node in self.expr.values[2:]:
        value = Expr.parse_value_expr(value_node, self.context)
        p = [op, value, p]
    return LLLnode.from_list(p, typ='bool', pos=getpos(self.expr))
def __init__(self, stmt, context):
    """Compile one statement AST node and store the result in ``self.lll_node``.

    The statement's exact class selects the compiling method via
    ``self.stmt_table``.

    Raises:
        StructureException: if the statement class has no handler.
    """
    self.stmt = stmt
    self.context = context
    # Exact AST statement class -> bound method that compiles it.
    self.stmt_table = {
        sri_ast.Expr: self.expr,
        sri_ast.Pass: self.parse_pass,
        sri_ast.AnnAssign: self.ann_assign,
        sri_ast.Assign: self.assign,
        sri_ast.If: self.parse_if,
        sri_ast.Call: self.call,
        sri_ast.Assert: self.parse_assert,
        sri_ast.For: self.parse_for,
        sri_ast.AugAssign: self.aug_assign,
        sri_ast.Break: self.parse_break,
        sri_ast.Continue: self.parse_continue,
        sri_ast.Return: self.parse_return,
        sri_ast.Name: self.parse_name,
        sri_ast.Raise: self.parse_raise,
    }
    handler = self.stmt_table.get(self.stmt.__class__)
    if handler is None:
        raise StructureException(
            f"Unsupported statement type: {type(stmt).__name__}", stmt)
    self.lll_node = handler()
def list_literals(self):
    """Compile a list literal into a ``multi`` LLL node of ``ListType``.

    All elements must agree on type (compared via the helper below). The
    list's subtype is the first element's type.

    Raises:
        StructureException: for an empty list literal.
        TypeMismatch: when element types differ.
    """
    elements = self.expr.elts
    if len(elements) == 0:
        raise StructureException("List must have elements", self.expr)

    def element_type(node):
        # NOTE(review): callers pass an LLLnode (presumably never a ListType),
        # so the recursive branch is not taken and this yields node.typ.
        if isinstance(node, ListType):
            return element_type(node.subtype)
        return node.typ

    compiled = []
    first_type = None
    previous_type = None
    for elt in elements:
        node = Expr(elt, self.context).lll_node
        if not first_type:
            first_type = node.typ
        current_type = element_type(node)
        if compiled and previous_type != current_type:
            raise TypeMismatch("Lists may only contain one type", self.expr)
        compiled.append(node)
        previous_type = current_type

    return LLLnode.from_list(
        ["multi"] + compiled,
        typ=ListType(first_type, len(compiled)),
        pos=getpos(self.expr),
    )
def unary_operations(self):
    """Compile a unary ``not`` or ``-`` expression.

    ``not x`` requires a ``bool`` operand and compiles to ``iszero``.
    ``-x`` requires a numeric operand and compiles to ``0 - x``, after
    clamping ``x`` to be strictly greater than the type's minimum value
    (that value cannot be negated).

    Raises:
        TypeMismatch: wrong operand type for the operator.
        StructureException: any other unary operator.
    """
    operand = Expr.parse_value_expr(self.expr.operand, self.context)

    if isinstance(self.expr.op, sri_ast.Not):
        if isinstance(operand.typ, BaseType) and operand.typ.typ == 'bool':
            return LLLnode.from_list(["iszero", operand], typ='bool', pos=getpos(self.expr))
        raise TypeMismatch(
            f"Only bool is supported for not operation, {operand.typ} supplied.",
            self.expr,
        )

    if isinstance(self.expr.op, sri_ast.USub):
        if not is_numeric_type(operand.typ):
            raise TypeMismatch(
                f"Unsupported type for negation: {operand.typ}",
                self.expr,
            )
        # Clamp on minimum integer value as we cannot negate that value
        # (all other integer values are fine)
        min_int_val = get_min_val_for_type(operand.typ.typ)
        return LLLnode.from_list(
            ["sub", 0, ["clampgt", operand, min_int_val]],
            typ=operand.typ,
            pos=getpos(self.expr),
        )

    # Pass the node so the error carries a source location like the others.
    raise StructureException(
        "Only the 'not' or 'neg' unary operators are supported", self.expr)
def tuple_literals(self):
    """Compile a tuple literal into a ``multi`` LLL node of literal ``TupleType``.

    Raises:
        StructureException: for an empty tuple literal.
    """
    if not len(self.expr.elts):
        raise StructureException("Tuple must have elements", self.expr)
    members = [Expr(elt, self.context).lll_node for elt in self.expr.elts]
    member_types = [member.typ for member in members]
    return LLLnode.from_list(
        ["multi", *members],
        typ=TupleType(member_types, is_literal=True),
        pos=getpos(self.expr),
    )
def extract_file_interface_imports(code: SourceCode) -> InterfaceImports:
    """Collect interface imports from a source file.

    - ``import a.b as X`` maps ``X`` to ``"a/b"``.
    - ``from <dots><mod> import Y`` maps ``Y`` to a path prefix built from
      the relative level (``./`` for one dot, ``../`` per extra dot), the
      module with dots replaced by ``/``, a trailing ``/``, then ``Y``.
    - ``from srilang.interfaces import ...`` (absolute) is skipped.

    Raises:
        StructureException: plain import without ``as``, aliased from-import,
            or a duplicate alias/name.
    """
    ast_tree = sri_ast.parse_to_ast(code)
    imports_dict: InterfaceImports = {}

    for item in ast_tree:
        if isinstance(item, sri_ast.Import):  # type: ignore
            for alias in item.names:  # type: ignore
                if not alias.asname:
                    raise StructureException(
                        'Interface statement requires an accompanying `as` statement.',
                        item,
                    )
                if alias.asname in imports_dict:
                    raise StructureException(
                        f'Interface with alias {alias.asname} already exists',
                        item,
                    )
                imports_dict[alias.asname] = alias.name.replace('.', '/')
            continue

        if not isinstance(item, sri_ast.ImportFrom):  # type: ignore
            continue

        if any(alias.asname for alias in item.names):  # type: ignore
            raise StructureException("From imports cannot use aliases", item)

        level = item.level  # type: ignore
        module = item.module or ""  # type: ignore
        if not level and module == 'srilang.interfaces':
            continue

        if level > 1:
            prefix = "../" * (level - 1)
        elif level == 1:
            prefix = "./"
        else:
            prefix = ""
        base_path = f"{prefix}{module.replace('.', '/')}/"

        for alias in item.names:  # type: ignore
            if alias.name in imports_dict:
                raise StructureException(
                    f'Interface with name {alias.name} already exists',
                    item,
                )
            imports_dict[alias.name] = f"{base_path}{alias.name}"

    return imports_dict
def add_constant(self, item, global_ctx):
    """Register a ``NAME: constant(<type>) = <value>`` declaration.

    Stores the unrolled value in ``self._constants`` and the raw value AST in
    ``self._constants_ast``, both keyed by the constant's name.

    Raises:
        StructureException: if no value is given, or if the annotation does
            not wrap exactly one Subscript/Name/Call type argument or has no
            target.
    """
    args = item.annotation.args
    if not item.value:
        raise StructureException('Constants must express a value!', item)

    well_formed = (
        len(args) == 1
        and isinstance(args[0], (sri_ast.Subscript, sri_ast.Name, sri_ast.Call))
        and item.target
    )
    if not well_formed:
        raise StructureException('Incorrectly formatted struct', item)

    c_name = item.target.id
    # TODO: if is_valid_varname() returns a falsy value nothing happens; the
    # constant is *silently* dropped because there is no else branch.
    # Confirm whether that is the intended behavior.
    if global_ctx.is_valid_varname(c_name, item):
        self._constants[c_name] = self.unroll_constant(item, global_ctx)
        self._constants_ast[c_name] = item.value
def g(element, context):
    """Validate and pre-process a builtin call's arguments, then call ``f``.

    ``argz`` (positional specs), ``kwargz`` (keyword specs) and ``f`` come
    from the enclosing scope, which is not visible here. Missing arguments
    whose spec is an ``Optional`` fall back to its ``default``. Each supplied
    argument goes through ``process_arg``.

    Raises:
        StructureException: too many/too few positional arguments, a missing
            required keyword, or an unexpected keyword.
    """
    function_name = element.func.id
    given_args = element.args
    if len(given_args) > len(argz):
        raise StructureException(
            f"Expected {len(argz)} arguments for {function_name}, "
            f"got {len(given_args)}",
            element
        )

    subs = []
    for position, expected_arg in enumerate(argz):
        if position < len(given_args):
            subs.append(process_arg(
                position + 1,
                given_args[position],
                expected_arg,
                function_name,
                context,
            ))
            continue
        if not isinstance(expected_arg, Optional):
            raise StructureException(
                f"Not enough arguments for function: {function_name}", element
            )
        subs.append(expected_arg.default)

    element_kw = {kw.arg: kw.value for kw in element.keywords}
    kwsubs = {}
    for name, expected_arg in kwargz.items():
        if name in element_kw:
            kwsubs[name] = process_arg(name, element_kw[name], expected_arg, function_name, context)
        elif isinstance(expected_arg, Optional):
            kwsubs[name] = expected_arg.default
        else:
            raise StructureException(f"Function {function_name} requires argument {name}", element)

    for name in element_kw:
        if name not in kwargz:
            raise StructureException(f"Unexpected argument: {name}", element)

    return f(element, subs, kwsubs, context)
def decorator_fn(self, node, context):
    """Validate a builtin call's arguments, then dispatch to ``wrapped_fn``.

    Positional specs are the second item of each entry in ``self._inputs``.
    Keyword specs come from ``self._kwargs`` (empty when absent).
    ``wrapped_fn`` comes from the enclosing scope, which is not visible here.
    Missing arguments whose spec is an ``Optional`` use its ``default``;
    supplied ones go through ``process_arg``.

    Raises:
        StructureException: too many/too few positional arguments, a missing
            required keyword, or an unexpected keyword.
    """
    positional_specs = [entry[1] for entry in self._inputs]
    keyword_specs = getattr(self, "_kwargs", {})
    function_name = node.func.id
    n_given = len(node.args)

    if n_given > len(positional_specs):
        raise StructureException(
            f"Expected {len(positional_specs)} arguments for {function_name}, got {n_given}",
            node
        )

    subs = []
    for idx, spec in enumerate(positional_specs):
        if idx >= n_given:
            if not isinstance(spec, Optional):
                raise StructureException(
                    f"Not enough arguments for function: {function_name}", node
                )
            subs.append(spec.default)
        else:
            subs.append(process_arg(idx + 1, node.args[idx], spec, function_name, context))

    given_kw = {kw.arg: kw.value for kw in node.keywords}
    kwsubs = {}
    for key, spec in keyword_specs.items():
        if key in given_kw:
            kwsubs[key] = process_arg(key, given_kw[key], spec, function_name, context)
            continue
        if not isinstance(spec, Optional):
            raise StructureException(
                f"Function {function_name} requires argument {key}", node
            )
        kwsubs[key] = spec.default

    for key in given_kw:
        if key not in keyword_specs:
            raise StructureException(f"Unexpected argument: {key}", node)

    return wrapped_fn(self, node, subs, kwsubs, context)
def abi_type_to_ast(atype, expected_size):
    """Translate an ABI type name into the equivalent srilang type AST node.

    Base types map to a ``Name`` of the same id, ``fixed168x10`` maps to
    ``decimal``, and ``bytes``/``string`` map to ``<atype>[expected_size]``.

    Raises:
        StructureException: for any other ABI type.
    """
    if atype in ('int128', 'uint256', 'bool', 'address', 'bytes32'):
        return sri_ast.Name(id=atype)
    if atype == 'fixed168x10':
        return sri_ast.Name(id='decimal')
    if atype not in ('bytes', 'string'):
        raise StructureException(f'Type {atype} not supported by srilang.')
    # expected_size is the maximum length for inputs, minimum length for outputs
    return sri_ast.Subscript(
        value=sri_ast.Name(id=atype),
        slice=sri_ast.Index(value=sri_ast.Int(value=expected_size)),
    )
def _check_valid_range_constant(self, arg_ast_node, raise_exception=True):
    """Parse a ``range()`` bound and report whether it is an integer literal.

    The node is parsed inside ``self.context.range_scope()``.

    Returns:
        ``(is_integer_literal, arg_expr)``: whether the parsed value is a
        literal ``uint256``/``int128``, and the parsed LLL.

    Raises:
        StructureException: if it is not such a literal and
            ``raise_exception`` is true.
    """
    with self.context.range_scope():
        # TODO should catch if raise_exception == False?
        arg_expr = Expr.parse_value_expr(arg_ast_node, self.context)

    value_type = arg_expr.typ
    is_integer_literal = (
        isinstance(value_type, BaseType)
        and value_type.is_literal
        and value_type.typ in {'uint256', 'int128'}
    )
    if raise_exception and not is_integer_literal:
        raise StructureException(
            "Range only accepts literal (constant) values of type uint256 or int128",
            arg_ast_node)
    return is_integer_literal, arg_expr
def make_call(stmt_expr, context):
    """Compile a ``self.<method>(...)`` call to a private function.

    Raises:
        ConstancyViolation: calling a non-constant function from a constant
            context.
        StructureException: the target is a public function.
    """
    method_name, _, sig = call_lookup_specs(stmt_expr, context)

    calling_mutator_from_constant = context.is_constant() and not sig.const
    if calling_mutator_from_constant:
        # NOTE(review): passes a position tuple rather than the node, unlike
        # the StructureException below; kept as-is.
        raise ConstancyViolation(
            f"May not call non-constant function '{method_name}' within {context.pp_constancy()}.",
            getpos(stmt_expr),
        )

    if sig.private:
        return call_self_private(stmt_expr, context, sig)
    raise StructureException("Cannot call public functions via 'self'", stmt_expr)
def convert(expr, context):
    """Compile ``convert(value, <type>)`` via ``CONVERSION_TABLE``.

    The target type may be a bare name (``uint256``) or a subscripted name
    (``bytes[32]``); only the base name selects the converter. A string
    target type still triggers a DeprecationWarning, but then fails as an
    invalid conversion type.

    Raises:
        StructureException: wrong number of arguments, unsupported target
            expression, or no converter for the target type.
    """
    if len(expr.args) != 2:
        raise StructureException(
            'The convert function expects two parameters.', expr)

    target = expr.args[1]
    if isinstance(target, sri_ast.Str):
        warnings.warn(
            "String parameter has been removed (see VIP1026). "
            "Use a srilang type instead.",
            DeprecationWarning)

    if isinstance(target, sri_ast.Name):
        output_type = target.id
    elif isinstance(target, sri_ast.Subscript) and isinstance(target.value, sri_ast.Name):
        output_type = target.value.id
    else:
        raise StructureException(
            "Invalid conversion type, use valid srilang type.", expr)

    if output_type not in CONVERSION_TABLE:
        raise StructureException(f"Conversion to {output_type} is invalid.", expr)
    return CONVERSION_TABLE[output_type](expr, context)
def ann_assign(self):
    """Compile an annotated assignment ``name: <type> = <value>``.

    Steps, all inside ``self.context.assignment_scope()``:
    - parse the annotation as a memory type;
    - allocate a new memory variable for the target;
    - compile the value;
    - emit a setter copying the value into the variable.

    A literal ``bytes[32]`` assigned to a ``bytes32`` variable is rewritten
    as a ``bytes32`` integer constant.

    Returns:
        The LLL produced by ``make_setter``.

    Raises:
        TypeMismatch: if the target is an attribute (e.g. ``self.x: ...``).
        StructureException: if no initial value is given.
    """
    with self.context.assignment_scope():
        typ = parse_type(
            self.stmt.annotation,
            location='memory',
            custom_structs=self.context.structs,
            constants=self.context.constants,
        )
        if isinstance(self.stmt.target, sri_ast.Attribute):
            raise TypeMismatch(
                f'May not set type for field {self.stmt.target.attr}',
                self.stmt,
            )
        # NOTE(review): a target that is neither an Attribute nor a Name
        # (e.g. a subscript) would fail here with AttributeError. Confirm
        # such targets are rejected earlier.
        varname = self.stmt.target.id
        # The variable is allocated before the value is checked and compiled.
        pos = self.context.new_variable(varname, typ)
        if self.stmt.value is None:
            raise StructureException(
                'New variables must be initialized explicitly', self.stmt)

        sub = Expr(self.stmt.value, self.context).lll_node

        is_literal_bytes32_assign = (isinstance(sub.typ, ByteArrayType)
                                     and sub.typ.maxlen == 32
                                     and isinstance(typ, BaseType)
                                     and typ.typ == 'bytes32'
                                     and sub.typ.is_literal)

        # If bytes[32] to bytes32 assignment rewrite sub as bytes32.
        if is_literal_bytes32_assign:
            sub = LLLnode(
                bytes_to_int(self.stmt.value.s),
                typ=BaseType('bytes32'),
                pos=getpos(self.stmt),
            )

        self._check_valid_assign(sub)
        self._check_same_variable_assign(sub)
        variable_loc = LLLnode.from_list(
            pos,
            typ=typ,
            location='memory',
            pos=getpos(self.stmt),
        )
        o = make_setter(variable_loc, sub, 'memory', pos=getpos(self.stmt))

        return o
def subscript(self):
    """Compile ``value[index]`` for mappings, lists and tuples.

    Mappings and lists accept any single-element index expression. Tuples
    require a literal integer index within ``range(len(members))``.

    Returns:
        LLL for the element's location, inheriting ``mutable`` from the base.

    Raises:
        StructureException: slicing a mapping or list.
        TypeMismatch: invalid tuple index (including slices), or subscripting
            any other type.
    """
    sub = Expr.parse_variable_location(self.expr.value, self.context)
    if isinstance(sub.typ, (MappingType, ListType)):
        if not isinstance(self.expr.slice, sri_ast.Index):
            raise StructureException(
                "Array access must access a single element, not a slice",
                self.expr,
            )
        index = Expr.parse_value_expr(self.expr.slice.value, self.context)
    elif isinstance(sub.typ, TupleType):
        # A slice (``t[a:b]``) has no ``.value``; report it as an invalid
        # index instead of failing with AttributeError below.
        if not isinstance(self.expr.slice, sri_ast.Index):
            raise TypeMismatch("Tuple index invalid", self.expr)
        index_node = self.expr.slice.value
        if (not isinstance(index_node, sri_ast.Int)
                or not 0 <= index_node.n < len(sub.typ.members)):
            raise TypeMismatch("Tuple index invalid", index_node)
        index = index_node.n
    else:
        raise TypeMismatch("Bad subscript attempt", self.expr.value)
    o = add_variable_offset(sub, index, pos=getpos(self.expr))
    o.mutable = sub.mutable
    return o
def parse_external_contracts(external_contracts, global_ctx):
    """Build function-signature tables for external contracts and interfaces.

    Each contract in ``global_ctx._contracts`` maps to a
    ``{name: FunctionSignature}`` dict. Every function body must be the
    single bare name ``constant`` or ``modifying``. Each interface in
    ``global_ctx._interfaces`` maps to its FunctionSignature entries.
    Results are written into (and returned as) ``external_contracts``.

    Raises:
        FunctionDeclarationException: duplicate function name in a contract.
        StructureException: a function without a call-type body.
    """
    for contract_name, contract_defs in global_ctx._contracts.items():
        def_names = [fn_def.name for fn_def in contract_defs]
        if len(set(def_names)) < len(contract_defs):
            duplicate = next(name for name in def_names if def_names.count(name) > 1)
            raise FunctionDeclarationException(f"Duplicate function name: {duplicate}")

        contract = {}
        for fn_def in contract_defs:
            body = fn_def.body
            # test for valid call type keyword.
            has_call_type = (
                len(body) == 1
                and isinstance(body[0], sri_ast.Expr)
                and isinstance(body[0].value, sri_ast.Name)
                and body[0].value.id in ('modifying', 'constant')
            )
            if not has_call_type:
                raise StructureException(
                    'constant or modifying call type must be specified', fn_def)
            # Recognizes already-defined structs
            sig = FunctionSignature.from_definition(
                fn_def,
                contract_def=True,
                constant_override=body[0].value.id == 'constant',
                custom_structs=global_ctx._structs,
                constants=global_ctx._constants,
            )
            contract[sig.name] = sig
        external_contracts[contract_name] = contract

    for interface_name, interface in global_ctx._interfaces.items():
        external_contracts[interface_name] = {
            sig.name: sig for sig in interface if isinstance(sig, FunctionSignature)
        }

    return external_contracts
def check_unmatched_return(fn_node):
    """Reject a function with a return type whose body can finish without returning.

    Raises:
        StructureException: when ``fn_node.returns`` is set and
            ``_return_check`` rejects the body.
    """
    if not fn_node.returns:
        return
    if _return_check(fn_node.body):
        return
    raise StructureException(
        f'Missing or Unmatched return statements in function "{fn_node.name}". '
        'All control flow statements (like if) need balanced return statements.',
        fn_node)
def pack_arguments(signature, args, context, stmt_expr, return_placeholder=True):
    """Pack call arguments into a freshly allocated memory buffer for a call.

    The buffer holds a 32-byte word with the method id followed by
    ``get_size_of_type(arg.typ) * 32`` bytes per declared argument. Packing
    by argument kind:
    - base types are written at their 32-byte head slot;
    - byte arrays/strings write the running offset ``_poz`` into their head
      slot, copy their data to ``placeholder + 32 + _poz``, and advance
      ``_poz`` past the 32-byte-rounded data;
    - static structs/lists are written in place and widen the head by
      ``32 * (count - 1)`` bytes.

    Args:
        signature: callee signature (uses ``args``, ``method_id``, ``name``).
        args: compiled LLL nodes for the actual arguments.
        context: compilation context used to allocate the buffer.
        stmt_expr: call AST node, used for positions and errors.
        return_placeholder: when true the generated ``seq`` ends with
            ``placeholder + 28``, the start of the right-aligned 4-byte
            method id.

    Returns:
        ``(lll, placeholder_typ.maxlen - 28, placeholder + 32)``.

    Raises:
        StructureException: argument count differs from the signature.
        TypeMismatch: struct/list argument with dynamic data, or an argument
            type that cannot be packed.
    """
    pos = getpos(stmt_expr)
    placeholder_typ = ByteArrayType(
        maxlen=sum(get_size_of_type(arg.typ) for arg in signature.args) * 32 + 32)
    placeholder = context.new_placeholder(placeholder_typ)
    setters = [['mstore', placeholder, signature.method_id]]
    needpos = False
    staticarray_offset = 0
    expected_arg_count = len(signature.args)
    actual_arg_count = len(args)
    if actual_arg_count != expected_arg_count:
        # Closing parenthesis was missing from this message.
        raise StructureException(
            f"Wrong number of args for: {signature.name} "
            f"({actual_arg_count} args given, expected {expected_arg_count})",
            stmt_expr)

    for i, (arg, typ) in enumerate(zip(args, [arg.typ for arg in signature.args])):
        if isinstance(typ, BaseType):
            setters.append(
                make_setter(LLLnode.from_list(
                    placeholder + staticarray_offset + 32 + i * 32,
                    typ=typ,
                ), arg, 'memory', pos=pos, in_function_call=True))

        elif isinstance(typ, ByteArrayLike):
            # Head slot receives the offset of the data in the tail.
            setters.append([
                'mstore', placeholder + staticarray_offset + 32 + i * 32, '_poz'
            ])
            arg_copy = LLLnode.from_list('_s', typ=arg.typ, location=arg.location)
            target = LLLnode.from_list(
                ['add', placeholder + 32, '_poz'],
                typ=typ,
                location='memory',
            )
            setters.append([
                'with', '_s', arg, [
                    'seq',
                    make_byte_array_copier(target, arg_copy, pos),
                    [
                        'set',
                        '_poz',
                        [
                            'add', 32, ['ceil32', ['add', '_poz', get_length(arg_copy)]]
                        ]
                    ],
                ],
            ])
            needpos = True

        elif isinstance(typ, (StructType, ListType)):
            if has_dynamic_data(typ):
                raise TypeMismatch("Cannot pack bytearray in struct", stmt_expr)
            target = LLLnode.from_list(
                [placeholder + 32 + staticarray_offset + i * 32],
                typ=typ,
                location='memory',
            )
            setters.append(make_setter(target, arg, 'memory', pos=pos))
            if (isinstance(typ, ListType)):
                count = typ.count
            else:
                count = len(typ.tuple_items())
            # The value occupies `count` words, one of which is already
            # accounted for by the `i * 32` head slot.
            staticarray_offset += 32 * (count - 1)

        else:
            raise TypeMismatch(f"Cannot pack argument of type {typ}", stmt_expr)

    # The trailing returner (start of the method id) is omitted when
    # return_placeholder is false (e.g. private call usage).
    returner = [[placeholder + 28]] if return_placeholder else []
    if needpos:
        return (
            LLLnode.from_list([
                'with', '_poz', len(args) * 32 + staticarray_offset, ['seq'] + setters + returner
            ], typ=placeholder_typ, location='memory'),
            placeholder_typ.maxlen - 28,
            placeholder + 32
        )
    else:
        return (
            LLLnode.from_list(['seq'] + setters + returner, typ=placeholder_typ, location='memory'),
            placeholder_typ.maxlen - 28,
            placeholder + 32
        )
def arithmetic(self):
    """Compile a binary arithmetic expression (+, -, *, /, %, **) to LLL.

    Both operands must be numeric and share a base type. When one operand
    is ``uint256`` and the other an ``int128`` literal that fits in
    ``uint256``, that literal is re-typed as ``uint256``. Operands are bound
    to ``l``/``r`` via ``with``. ``int128`` and ``decimal`` results are
    clamped to their type's bounds; ``uint256`` add/sub/mul/pow carry their
    own overflow asserts.

    Returns:
        LLLnode typed with the operands' base type.

    Raises:
        TypeMismatch: non-numeric or mismatched operand types, decimal
            exponentiation, or an unsupported exponent form.
        ZeroDivisionException: division or modulus by a literal zero.
        StructureException: unsupported binary operator.
    """
    left = Expr.parse_value_expr(self.expr.left, self.context)
    right = Expr.parse_value_expr(self.expr.right, self.context)

    if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
        raise TypeMismatch(
            f"Unsupported types for arithmetic op: {left.typ} {right.typ}",
            self.expr,
        )

    arithmetic_pair = {left.typ.typ, right.typ.typ}
    pos = getpos(self.expr)

    # Special case with uint256 where an int literal may be cast.
    if arithmetic_pair == {'uint256', 'int128'}:
        # Check right side literal.
        if right.typ.is_literal and SizeLimits.in_bounds('uint256', right.value):
            right = LLLnode.from_list(
                right.value,
                typ=BaseType('uint256', None, is_literal=True),
                pos=pos,
            )

        # Check left side literal.
        elif left.typ.is_literal and SizeLimits.in_bounds('uint256', left.value):
            left = LLLnode.from_list(
                left.value,
                typ=BaseType('uint256', None, is_literal=True),
                pos=pos,
            )

    if left.typ.typ == "decimal" and isinstance(self.expr.op, sri_ast.Pow):
        raise TypeMismatch(
            "Cannot perform exponentiation on decimal values.",
            self.expr,
        )

    # Only allow explicit conversions to occur.
    if left.typ.typ != right.typ.typ:
        raise TypeMismatch(
            f"Cannot implicitly convert {left.typ.typ} to {right.typ.typ}.",
            self.expr,
        )

    ltyp, rtyp = left.typ.typ, right.typ.typ
    if isinstance(self.expr.op, (sri_ast.Add, sri_ast.Sub)):
        new_typ = BaseType(ltyp)
        op = 'add' if isinstance(self.expr.op, sri_ast.Add) else 'sub'

        if ltyp == 'uint256' and isinstance(self.expr.op, sri_ast.Add):
            # safeadd: the sum wrapped iff it is smaller than l
            arith = ['seq',
                     ['assert', ['ge', ['add', 'l', 'r'], 'l']],
                     ['add', 'l', 'r']]

        elif ltyp == 'uint256' and isinstance(self.expr.op, sri_ast.Sub):
            # safesub: reject r > l, which would underflow
            arith = ['seq',
                     ['assert', ['ge', 'l', 'r']],
                     ['sub', 'l', 'r']]

        elif ltyp == rtyp:
            # int128/decimal: range is enforced by the clamp further down
            arith = [op, 'l', 'r']

        else:
            raise Exception(f"Unsupported Operation '{op}({ltyp}, {rtyp})'")

    elif isinstance(self.expr.op, sri_ast.Mult):
        new_typ = BaseType(ltyp)
        if ltyp == rtyp == 'uint256':
            # overflow check: (l * r) / l must give back r unless l == 0
            arith = ['with', 'ans', ['mul', 'l', 'r'],
                     ['seq',
                         ['assert',
                             ['or',
                                 ['eq', ['div', 'ans', 'l'], 'r'],
                                 ['iszero', 'l']]],
                         'ans']]

        elif ltyp == rtyp == 'int128':
            # TODO should this be 'smul' (note edge cases in YP for smul)
            arith = ['mul', 'l', 'r']

        elif ltyp == rtyp == 'decimal':
            # TODO should this be smul
            # The raw product carries the scale factor twice, hence the
            # final division by DECIMAL_DIVISOR.
            arith = ['with', 'ans', ['mul', 'l', 'r'],
                     ['seq',
                         ['assert',
                             ['or',
                                 ['eq', ['sdiv', 'ans', 'l'], 'r'],
                                 ['iszero', 'l']]],
                         ['sdiv', 'ans', DECIMAL_DIVISOR]]]
        else:
            raise Exception(f"Unsupported Operation 'mul({ltyp}, {rtyp})'")

    elif isinstance(self.expr.op, sri_ast.Div):
        if right.typ.is_literal and right.value == 0:
            raise ZeroDivisionException("Cannot divide by 0.", self.expr)

        new_typ = BaseType(ltyp)
        if ltyp == rtyp == 'uint256':
            arith = ['div', 'l', ['clamp_nonzero', 'r']]

        elif ltyp == rtyp == 'int128':
            arith = ['sdiv', 'l', ['clamp_nonzero', 'r']]

        elif ltyp == rtyp == 'decimal':
            # l is pre-scaled so the quotient keeps the decimal scale.
            arith = ['sdiv',
                     # TODO check overflow cases, also should it be smul
                     ['mul', 'l', DECIMAL_DIVISOR],
                     ['clamp_nonzero', 'r']]

        else:
            raise Exception(f"Unsupported Operation 'div({ltyp}, {rtyp})'")

    elif isinstance(self.expr.op, sri_ast.Mod):
        if right.typ.is_literal and right.value == 0:
            raise ZeroDivisionException("Cannot calculate modulus of 0.", self.expr)

        new_typ = BaseType(ltyp)

        if ltyp == rtyp == 'uint256':
            arith = ['mod', 'l', ['clamp_nonzero', 'r']]
        elif ltyp == rtyp:
            # TODO should this be regular mod
            arith = ['smod', 'l', ['clamp_nonzero', 'r']]

        else:
            raise Exception(f"Unsupported Operation 'mod({ltyp}, {rtyp})'")
    elif isinstance(self.expr.op, sri_ast.Pow):
        if ltyp != 'int128' and ltyp != 'uint256' and isinstance(self.expr.right, sri_ast.Name):
            raise TypeMismatch(
                "Cannot use dynamic values as exponents, for unit base types",
                self.expr,
            )
        new_typ = BaseType(ltyp)

        if ltyp == rtyp == 'uint256':
            # NOTE(review): this overflow check looks unsound in both directions:
            # - for l in {0, 1} and r >= 2, `l < l**r` is false, so the assert
            #   fails even though the result is exact;
            # - a wrapped (mod 2**256) result that still exceeds l passes
            #   undetected.
            # Confirm before relying on it.
            arith = ['seq',
                     ['assert',
                         ['or',
                             # r == 1 or r == 0 (i.e. r < 2)
                             ['or', ['eq', 'r', 1], ['iszero', 'r']],
                             ['lt', 'l', ['exp', 'l', 'r']]]],
                     ['exp', 'l', 'r']]
        elif ltyp == rtyp == 'int128':
            # NOTE(review): `exp` wraps modulo 2**256 before the int128 clamp
            # below, so some overflows may land back inside the int128 range.
            arith = ['exp', 'l', 'r']

        else:
            raise TypeMismatch('Only whole number exponents are supported', self.expr)
    else:
        raise StructureException(f"Unsupported binary operator: {self.expr.op}", self.expr)

    p = ['seq']

    # Enforce the result type's bounds (uint256 ops carry their own checks).
    if new_typ.typ == 'int128':
        p.append([
            'clamp',
            ['mload', MemoryPositions.MINNUM],
            arith,
            ['mload', MemoryPositions.MAXNUM],
        ])
    elif new_typ.typ == 'decimal':
        p.append([
            'clamp',
            ['mload', MemoryPositions.MINDECIMAL],
            arith,
            ['mload', MemoryPositions.MAXDECIMAL],
        ])
    elif new_typ.typ == 'uint256':
        p.append(arith)
    else:
        raise Exception(f"{arith} {new_typ}")

    p = ['with', 'l', left, ['with', 'r', right, p]]
    return LLLnode.from_list(p, typ=new_typ, pos=pos)
def from_definition(cls,
                    code,
                    sigs=None,
                    custom_structs=None,
                    contract_def=False,
                    constants=None,
                    constant_override=False):
    """Build a function signature object from a function-definition AST node.

    Validates the function name, argument names and types, decorators and
    return type, assigns each argument a memory offset, and computes the
    4-byte method ID from the canonical signature.

    Args:
        code: the function-definition AST node (``name``, ``args``,
            ``decorator_list`` and ``returns`` are read).
        sigs: known contract signatures, forwarded to ``parse_type`` and
            ``get_full_sig``.
        custom_structs: user-defined struct types; a falsy value means none.
        contract_def: when true, a visibility decorator is not required
            (presumably set while parsing an external contract declaration).
        constants: user-defined constants, forwarded to name/type checks.
        constant_override: initial value of the ``const`` flag; a
            ``@constant`` decorator can only set it, never clear it.

    Returns:
        ``cls(name, args, output_type, const, payable, private,
        nonreentrant_key, sig, method_id, code)``.

    Raises:
        FunctionDeclarationException: invalid function name, invalid or
            duplicate argument name.
        InvalidType: missing argument type or unsupported return type.
        StructureException: invalid or conflicting decorators.
    """
    if not custom_structs:
        custom_structs = {}

    name = code.name
    mem_pos = 0

    valid_name, msg = is_varname_valid(name, custom_structs, constants)
    if not valid_name and name.lower() not in FUNCTION_WHITELIST:
        raise FunctionDeclarationException("Function name invalid. " + msg, code)

    # Validate default values.
    for default_value in getattr(code.args, 'defaults', []):
        validate_default_values(default_value)

    # Determine the arguments, expects something of the form
    # def foo(arg1: int128, arg2: int128 ...
    args = []
    for arg in code.args.args:
        # Each arg needs a type specified.
        typ = arg.annotation
        if not typ:
            raise InvalidType("Argument must have type", arg)
        # Validate arg name.
        check_valid_varname(
            arg.arg,
            custom_structs,
            constants,
            arg,
            "Argument name invalid or reserved. ",
            FunctionDeclarationException,
        )
        # Check for duplicate arg name.
        if arg.arg in (x.name for x in args):
            raise FunctionDeclarationException(
                "Duplicate function argument name: " + arg.arg,
                arg,
            )
        parsed_type = parse_type(
            typ,
            None,
            sigs,
            custom_structs=custom_structs,
            constants=constants,
        )
        args.append(VariableRecord(
            arg.arg,
            mem_pos,
            parsed_type,
            False,
            defined_at=getpos(arg),
        ))
        # Byte arrays reserve a single 32-byte word here (presumably an
        # offset to the data); every other type reserves its full size.
        if isinstance(parsed_type, ByteArrayLike):
            mem_pos += 32
        else:
            mem_pos += get_size_of_type(parsed_type) * 32

    const = constant_override
    payable = False
    private = False
    public = False
    nonreentrant_key = ''

    # Update function properties from decorators
    for dec in code.decorator_list:
        if isinstance(dec, sri_ast.Name) and dec.id == "constant":
            const = True
        elif isinstance(dec, sri_ast.Name) and dec.id == "payable":
            payable = True
        elif isinstance(dec, sri_ast.Name) and dec.id == "private":
            private = True
        elif isinstance(dec, sri_ast.Name) and dec.id == "public":
            public = True
        elif (isinstance(dec, sri_ast.Call)
              and isinstance(dec.func, sri_ast.Name)
              and dec.func.id == "nonreentrant"):
            # ``dec.func`` must be a Name before reading ``.id``: a call
            # decorator such as ``@a.b()`` falls through to "Bad decorator"
            # instead of crashing with AttributeError.
            if nonreentrant_key:
                raise StructureException(
                    "Only one @nonreentrant decorator allowed per function",
                    dec,
                )
            if dec.args and len(dec.args) == 1 and isinstance(
                    dec.args[0], sri_ast.Str) and dec.args[0].s:
                nonreentrant_key = dec.args[0].s
            else:
                raise StructureException(
                    "@nonreentrant decorator requires a non-empty string to use as a key.",
                    dec,
                )
        else:
            raise StructureException("Bad decorator", dec)

    # Conflicting decorators: pass ``code`` so errors carry a source position,
    # like every other StructureException raised here.
    if public and private:
        raise StructureException(
            f"Cannot use public and private decorators on the same function: {name}",
            code,
        )
    if payable and const:
        raise StructureException(
            f"Function {name} cannot be both constant and payable.",
            code,
        )
    if payable and private:
        raise StructureException(
            f"Function {name} cannot be both private and payable.",
            code,
        )
    if (not public and not private) and not contract_def:
        raise StructureException(
            "Function visibility must be declared (@public or @private)",
            code,
        )
    if const and nonreentrant_key:
        raise StructureException(
            "@nonreentrant makes no sense on a @constant function.",
            code,
        )

    # Determine the return type. Expects something of the form:
    # def foo(): ...
    # def foo() -> int128: ...
    # If there is no return type, ie. it's of the form def foo(): ...
    # and NOT def foo() -> type: ..., then it's null
    if not code.returns:
        output_type = None
    elif isinstance(code.returns, (sri_ast.Name, sri_ast.Compare, sri_ast.Subscript,
                                   sri_ast.Call, sri_ast.Tuple)):
        output_type = parse_type(
            code.returns,
            None,
            sigs,
            custom_structs=custom_structs,
            constants=constants,
        )
    else:
        # NOTE(review): parse_type is evaluated while building this message;
        # if it cannot handle the node it presumably raises its own error
        # before this one is constructed — confirm against parse_type.
        raise InvalidType(
            f"Output type invalid or unsupported: {parse_type(code.returns, None)}",
            code.returns,
        )
    # Output type must be canonicalizable
    if output_type is not None:
        assert isinstance(output_type, TupleType) or canonicalize_type(output_type)
    # Get the canonical function signature
    sig = cls.get_full_sig(name, code.args.args, sigs, custom_structs, constants)

    # Take the first 4 bytes of the hash of the sig to get the method ID
    method_id = fourbytes_to_int(keccak256(bytes(sig, 'utf-8'))[:4])
    return cls(name, args, output_type, const, payable, private, nonreentrant_key, sig, method_id,
               code)
def compare(self):
    """Compile the comparison expression ``self.expr`` into a boolean LLL node.

    Three families are handled:

    * ``x in some_list`` -- delegated to ``self.build_in_comparator()``
      after checking the element type matches the list subtype.
    * byte array / string operands -- compared word-wise when both have
      the same max length of at most 32 bytes, otherwise only ``==`` /
      ``!=`` are allowed and are done by comparing keccak256 hashes.
    * everything else -- numeric ordering and equality, switching to
      unsigned opcodes (via ``_signed_to_unsigned_comparision_op``) for
      uint256 operands and for an int128/uint256 pair where one side is an
      in-range literal.

    Raises:
        InvalidLiteral: the right-hand side is ``None``.
        TypeMismatch: operand types cannot be compared.
        StructureException: ordering comparison of long or differently
            sized byte arrays.
    """
    left = Expr.parse_value_expr(self.expr.left, self.context)
    right = Expr.parse_value_expr(self.expr.right, self.context)

    if right.value is None:
        raise InvalidLiteral(
            'Comparison to None is not allowed, compare against a default value.',
            self.expr,
        )

    if isinstance(left.typ, ByteArrayLike) and isinstance(right.typ, ByteArrayLike):
        # TODO: Can this if branch be removed ^
        pass

    elif isinstance(self.expr.op, sri_ast.In) and isinstance(right.typ, ListType):
        if left.typ != right.typ.subtype:
            raise TypeMismatch(
                "Can't use IN comparison with different types!",
                self.expr,
            )
        return self.build_in_comparator()

    if isinstance(self.expr.op, sri_ast.Gt):
        op = 'sgt'
    elif isinstance(self.expr.op, sri_ast.GtE):
        op = 'sge'
    elif isinstance(self.expr.op, sri_ast.LtE):
        op = 'sle'
    elif isinstance(self.expr.op, sri_ast.Lt):
        op = 'slt'
    elif isinstance(self.expr.op, sri_ast.Eq):
        op = 'eq'
    elif isinstance(self.expr.op, sri_ast.NotEq):
        op = 'ne'
    else:
        raise Exception("Unsupported comparison operator")

    # Compare (limited to 32) byte arrays.
    if isinstance(left.typ, ByteArrayLike) and isinstance(right.typ, ByteArrayLike):
        left = Expr(self.expr.left, self.context).lll_node
        right = Expr(self.expr.right, self.context).lll_node

        length_mismatch = (left.typ.maxlen != right.typ.maxlen)
        left_over_32 = left.typ.maxlen > 32
        right_over_32 = right.typ.maxlen > 32
        if length_mismatch or left_over_32 or right_over_32:
            left_keccak = keccak256_helper(self.expr, [left], None, self.context)
            right_keccak = keccak256_helper(self.expr, [right], None, self.context)

            if op == 'eq' or op == 'ne':
                return LLLnode.from_list(
                    [op, left_keccak, right_keccak],
                    typ='bool',
                    pos=getpos(self.expr),
                )

            else:
                # The message must be a single string argument: the second
                # positional argument is the AST node used for error positions.
                raise StructureException(
                    "Can only compare strings/bytes of length shorter"
                    " than 32 bytes other than equality comparisons",
                    self.expr,
                )

        else:
            def load_bytearray(side):
                # Load the single data word that follows the length word.
                if side.location == 'memory':
                    return ['mload', ['add', 32, side]]
                elif side.location == 'storage':
                    return ['sload', ['add', 1, ['sha3_32', side]]]
                # Fail loudly instead of returning None into the LLL list.
                raise Exception(
                    f"Unsupported location for byte array comparison: {side.location}"
                )

            return LLLnode.from_list(
                [op, load_bytearray(left), load_bytearray(right)],
                typ='bool',
                pos=getpos(self.expr),
            )

    # Compare other types.
    if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
        if op not in ('eq', 'ne'):
            raise TypeMismatch("Invalid type for comparison op", self.expr)
    left_type, right_type = left.typ.typ, right.typ.typ

    # Special Case: comparison of a literal integer. If in valid range allow it to be compared.
    if {left_type, right_type} == {'int128', 'uint256'} and {left.typ.is_literal, right.typ.is_literal} == {True, False}:  # noqa: E501

        comparison_allowed = False
        if left.typ.is_literal and SizeLimits.in_bounds(right_type, left.value):
            comparison_allowed = True
        elif right.typ.is_literal and SizeLimits.in_bounds(left_type, right.value):
            comparison_allowed = True
        op = self._signed_to_unsigned_comparision_op(op)

        if comparison_allowed:
            return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))

    elif {left_type, right_type} == {'uint256', 'uint256'}:
        op = self._signed_to_unsigned_comparision_op(op)
    elif (left_type in ('decimal', 'int128') or right_type in ('decimal', 'int128')) and left_type != right_type:  # noqa: E501
        raise TypeMismatch(
            f'Implicit conversion from {left_type} to {right_type} disallowed, please convert.',
            self.expr,
        )

    if left_type == right_type:
        return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))
    else:
        raise TypeMismatch(
            f"Unsupported types for comparison: {left_type} {right_type}",
            self.expr,
        )
def external_contract_call(node,
                           context,
                           contract_name,
                           contract_address,
                           pos,
                           value=None,
                           gas=None):
    """Compile an external call ``node`` into LLL.

    The generated code asserts that the target has code and is not this
    contract, then performs a ``staticcall`` (when the calling context is
    constant or the callee is declared constant) or a ``call`` forwarding
    ``value``, and finally appends the output handling produced by
    ``get_external_contract_call_output``.

    Args:
        node: the call AST node; ``node.func.attr`` is the method name and
            ``node.args`` are the call arguments.
        context: the compilation context (signatures, constancy, memory).
        contract_name: name of the declared contract whose signatures are
            looked up in ``context.sigs``.
        contract_address: LLL node evaluating to the target address.
        pos: source position. NOTE(review): currently unused; the result
            is positioned with ``getpos(node)``.
        value: wei to send with a non-static call; defaults to 0.
        gas: gas operand for the call; defaults to the LLL ``gas`` opcode.

    Returns:
        An ``LLLnode`` typed as the callee's output type, located in memory.

    Raises:
        StructureException: missing contract name, or a call to ``self``.
        VariableDeclarationException: the contract is not declared.
        FunctionDeclarationException: the method is not declared on it.
        ConstancyViolation: a non-constant method called from a constant
            context.
    """
    from srilang.parser.expr import (
        Expr,
    )
    if value is None:
        value = 0
    if gas is None:
        gas = 'gas'
    if not contract_name:
        raise StructureException(
            f'Invalid external contract call "{node.func.attr}".',
            node,
        )
    if contract_name not in context.sigs:
        raise VariableDeclarationException(
            f'Contract "{contract_name}" not declared yet',
            node,
        )
    if contract_address.value == "address":
        raise StructureException("External calls to self are not permitted.", node)
    method_name = node.func.attr
    if method_name not in context.sigs[contract_name]:
        # ". " separates the reminder from the list of available methods;
        # previously the two sentences ran together ("contract)The ...").
        raise FunctionDeclarationException(
            (
                f"Function not declared yet: {method_name} (reminder: "
                "function must be declared in the correct contract). "
                f"The available methods are: {','.join(context.sigs[contract_name].keys())}"
            ),
            node.func,
        )
    sig = context.sigs[contract_name][method_name]
    inargs, inargsize, _ = pack_arguments(
        sig,
        [Expr(arg, context).lll_node for arg in node.args],
        context,
        node.func,
    )
    output_placeholder, output_size, returner = get_external_contract_call_output(sig, context)
    sub = [
        'seq',
        # Target must have deployed code and must not be this contract.
        ['assert', ['extcodesize', contract_address]],
        ['assert', ['ne', 'address', contract_address]],
    ]
    if context.is_constant() and not sig.const:
        raise ConstancyViolation(
            f"May not call non-constant function '{method_name}' within {context.pp_constancy()}."
            " For asserting the result of modifiable contract calls, try assert_modifiable.",
            node,
        )

    if context.is_constant() or sig.const:
        # staticcall takes no value operand.
        sub.append([
            'assert',
            [
                'staticcall',
                gas, contract_address, inargs, inargsize, output_placeholder, output_size,
            ]
        ])
    else:
        sub.append([
            'assert',
            [
                'call',
                gas, contract_address, value, inargs, inargsize, output_placeholder, output_size,
            ]
        ])
    sub.extend(returner)
    o = LLLnode.from_list(sub, typ=sig.output_type, location='memory', pos=getpos(node))
    return o
def make_external_call(stmt_expr, context):
    """Resolve the target of an external call expression and compile it.

    Supported call shapes:

    * ``Contract(addr).method(...)``: the contract name is the called name
      and ``addr`` is evaluated as the target address.
    * ``self.<var>.method(...)`` where ``<var>`` is a global (storage)
      variable: the contract name is ``<var>`` itself when that name is a
      declared contract, otherwise the name of the variable's type. The
      receiver of ``<var>`` is not checked; the address is read from the
      variable's storage slot.

    Args:
        stmt_expr: the call AST node.
        context: the compilation context.

    Returns:
        The LLL node produced by ``external_contract_call``.

    Raises:
        StructureException: "Unsupported operator." for any other shape.
    """
    from srilang.parser.expr import Expr
    value, gas = get_external_contract_keywords(stmt_expr, context)

    func = stmt_expr.func
    # Every supported shape is ``<target>.method(...)``; anything else would
    # otherwise crash on ``.value`` / ``.attr`` with AttributeError.
    if not isinstance(func, sri_ast.Attribute):
        raise StructureException("Unsupported operator.", stmt_expr)
    target = func.value

    if isinstance(target, sri_ast.Call):
        # ``Contract(addr).method(...)``
        contract_name = target.func.id
        contract_address = Expr.parse_value_expr(target.args[0], context)

    elif isinstance(target, sri_ast.Attribute) and target.attr in context.globals:
        # ``self.<var>.method(...)`` -- requiring ``target.attr`` to be a
        # global avoids a KeyError on the lookup below.
        var = context.globals[target.attr]
        if target.attr in context.sigs:
            contract_name = target.attr
        elif hasattr(var.typ, 'name'):
            contract_name = var.typ.name
        else:
            raise StructureException("Unsupported operator.", stmt_expr)
        contract_address = unwrap_location(
            LLLnode.from_list(
                var.pos,
                typ=var.typ,
                location='storage',
                pos=getpos(stmt_expr),
                annotation='self.' + target.attr,
            ))

    else:
        raise StructureException("Unsupported operator.", stmt_expr)

    return external_contract_call(
        stmt_expr,
        context,
        contract_name,
        contract_address,
        pos=getpos(stmt_expr),
        value=value,
        gas=gas,
    )
def call_self_private(stmt_expr, context, sig):
    """Compile a call to a private function of the current contract into LLL.

    The call is built as one ``seq_unchecked`` sequence that saves the
    caller's memory-resident local variables on the stack, pushes the
    packed arguments, pushes a return address and jumps to the label
    ``priv_<method_id>``, pops the return values into an output
    placeholder, and finally restores the saved locals.

    Args:
        stmt_expr: the call AST node (its ``lineno``/``col_offset`` make the
            generated labels unique).
        context: the compilation context.
        sig: overwritten by the signature returned from
            ``call_lookup_specs`` before it is used.

    Returns:
        An ``LLLnode`` typed as ``sig.output_type`` (located in memory); when
        the function has no output type the call body is wrapped in ``pop``.

    Raises:
        StructureException: no arguments given but the signature expects some.
    """
    # ** Private Call **
    # Steps:
    # (x) push current local variables
    # (x) push arguments
    # (x) push jumpdest (callback ptr)
    # (x) jump to label
    # (x) pop return values
    # (x) pop local variables

    method_name, expr_args, sig = call_lookup_specs(stmt_expr, context)
    pre_init = []
    pop_local_vars = []
    push_local_vars = []
    pop_return_values = []
    push_args = []

    # Push local variables.
    # Save the contiguous span from the lowest to the end of the highest
    # memory-resident variable (v.size counts 32-byte words).
    var_slots = [
        (v.pos, v.size) for name, v in context.vars.items()
        if v.location == 'memory'
    ]
    if var_slots:
        var_slots.sort(key=lambda x: x[0])
        mem_from, mem_to = var_slots[0][0], var_slots[-1][0] + var_slots[-1][1] * 32

        i_placeholder = context.new_placeholder(BaseType('uint256'))
        local_save_ident = f"_{stmt_expr.lineno}_{stmt_expr.col_offset}"
        push_loop_label = 'save_locals_start' + local_save_ident
        pop_loop_label = 'restore_locals_start' + local_save_ident

        if mem_to - mem_from > 320:
            # More than 10 words: emit a loop instead of unrolled loads/stores.
            push_local_vars = [
                ['mstore', i_placeholder, mem_from],
                ['label', push_loop_label],
                ['mload', ['mload', i_placeholder]],
                ['mstore', i_placeholder, ['add', ['mload', i_placeholder], 32]],
                ['if', ['lt', ['mload', i_placeholder], mem_to],
                    ['goto', push_loop_label]]
            ]
            # Restore in reverse order (highest word first), matching the
            # order in which the words were pushed.
            pop_local_vars = [
                ['mstore', i_placeholder, mem_to - 32],
                ['label', pop_loop_label],
                ['mstore', ['mload', i_placeholder], 'pass'],
                ['mstore', i_placeholder, ['sub', ['mload', i_placeholder], 32]],
                ['if', ['ge', ['mload', i_placeholder], mem_from],
                       ['goto', pop_loop_label]]
            ]
        else:
            push_local_vars = [['mload', pos] for pos in range(mem_from, mem_to, 32)]
            pop_local_vars = [['mstore', pos, 'pass'] for pos in range(mem_to - 32, mem_from - 32, -32)]

    # Push Arguments
    if expr_args:
        inargs, inargsize, arg_pos = pack_arguments(
            sig,
            expr_args,
            context,
            stmt_expr,
            return_placeholder=False,
        )
        push_args += [inargs]  # copy arguments first, to not mess up the push/pop sequencing.

        static_arg_size = 32 * sum(
                [get_static_size_of_type(arg.typ)
                    for arg in expr_args])
        static_pos = int(arg_pos + static_arg_size)
        needs_dyn_section = any(
            [has_dynamic_data(arg.typ)
                for arg in expr_args])

        if needs_dyn_section:
            ident = f'push_args_{sig.method_id}_{stmt_expr.lineno}_{stmt_expr.col_offset}'
            start_label = ident + '_start'
            end_label = ident + '_end'
            i_placeholder = context.new_placeholder(BaseType('uint256'))

            # Calculate copy start position.
            # Given | static | dynamic | section in memory,
            # copy backwards so the values are in order on the stack.
            # We calculate i, the end of the whole encoded part
            # (i.e. the starting index for copy)
            # by taking ceil32(len<arg>) + offset<arg> + arg_pos
            # for the last dynamic argument and arg_pos is the start
            # the whole argument section.
            # NOTE(review): last_idx is only bound when some argument is
            # ByteArrayLike; if has_dynamic_data() can be true for other
            # argument types, the use below would raise NameError — confirm.
            idx = 0
            for arg in expr_args:
                if isinstance(arg.typ, ByteArrayLike):
                    last_idx = idx
                idx += get_static_size_of_type(arg.typ)
            push_args += [
                ['with', 'offset', ['mload', arg_pos + last_idx * 32],
                    ['with', 'len_pos', ['add', arg_pos, 'offset'],
                        ['with', 'len_value', ['mload', 'len_pos'],
                            ['mstore', i_placeholder, ['add', 'len_pos', ['ceil32', 'len_value']]]]]]
            ]
            # loop from end of dynamic section to start of dynamic section,
            # pushing each element onto the stack.
            push_args += [
                ['label', start_label],
                ['if', ['lt', ['mload', i_placeholder], static_pos],
                    ['goto', end_label]],
                ['mload', ['mload', i_placeholder]],
                ['mstore', i_placeholder, ['sub', ['mload', i_placeholder], 32]],  # decrease i
                ['goto', start_label],
                ['label', end_label]
            ]

        # push static section
        # (last word first, so the first static word ends up on top)
        push_args += [
            ['mload', pos] for pos in reversed(range(arg_pos, static_pos, 32))
        ]
    elif sig.args:
        raise StructureException(
            f"Wrong number of args for: {sig.name} (0 args given, expected {len(sig.args)})",
            stmt_expr
        )

    # Jump to function label.
    # The return address is pushed first; pc + 6 presumably lands on the
    # jumpdest that follows the goto — verify if the emitted goto changes.
    jump_to_func = [
        ['add', ['pc'], 6],  # set callback pointer.
        ['goto', f'priv_{sig.method_id}'],
        ['jumpdest'],
    ]

    # Pop return values.
    # In these lists, 'pass' as a value operand presumably stores the word
    # already on the stack — confirm against the LLL compiler.
    returner = [0]
    if sig.output_type:
        output_placeholder, returner, output_size = call_make_placeholder(stmt_expr, context, sig)
        if output_size > 0:
            dynamic_offsets = []
            if isinstance(sig.output_type, (BaseType, ListType)):
                # Fixed-size output: pop each word into the placeholder in order.
                pop_return_values = [
                    ['mstore', ['add', output_placeholder, pos], 'pass']
                    for pos in range(0, output_size, 32)
                ]
            elif isinstance(sig.output_type, ByteArrayLike):
                # Discard one word, then let the dynamic unpacker below read the
                # length and data into the start of the placeholder.
                dynamic_offsets = [(0, sig.output_type)]
                pop_return_values = [
                    ['pop', 'pass'],
                ]
            elif isinstance(sig.output_type, TupleLike):
                # Static head first: byte-array members store one word (read
                # back below as that member's offset), other members store
                # their full size.
                static_offset = 0
                pop_return_values = []
                for name, typ in sig.output_type.tuple_items():
                    if isinstance(typ, ByteArrayLike):
                        pop_return_values.append(
                            ['mstore', ['add', output_placeholder, static_offset], 'pass']
                        )
                        dynamic_offsets.append(
                            (['mload', ['add', output_placeholder, static_offset]], name)
                        )
                        static_offset += 32
                    else:
                        member_output_size = get_size_of_type(typ) * 32
                        pop_return_values.extend([
                            ['mstore', ['add', output_placeholder, pos], 'pass']
                            for pos in range(static_offset, static_offset + member_output_size, 32)
                        ])
                        static_offset += member_output_size

            # append dynamic unpacker.
            # For each dynamic member: pop its length into begin_pos, then pop
            # ceil32(length) bytes of data into the words after it.
            dyn_idx = 0
            for in_memory_offset, _out_type in dynamic_offsets:
                ident = f"{stmt_expr.lineno}_{stmt_expr.col_offset}_arg_{dyn_idx}"
                dyn_idx += 1
                start_label = 'dyn_unpack_start_' + ident
                end_label = 'dyn_unpack_end_' + ident
                i_placeholder = context.new_placeholder(typ=BaseType('uint256'))
                begin_pos = ['add', output_placeholder, in_memory_offset]
                # loop until length.
                o = LLLnode.from_list(
                    ['seq_unchecked',
                        ['mstore', begin_pos, 'pass'],  # get len
                        ['mstore', i_placeholder, 0],
                        ['label', start_label],
                        [  # break
                            'if',
                            ['ge', ['mload', i_placeholder], ['ceil32', ['mload', begin_pos]]],
                            ['goto', end_label]
                        ],
                        [  # pop into correct memory slot.
                            'mstore',
                            ['add', ['add', begin_pos, 32], ['mload', i_placeholder]],
                            'pass',
                        ],
                        # increment i
                        ['mstore', i_placeholder, ['add', 32, ['mload', i_placeholder]]],
                        ['goto', start_label],
                        ['label', end_label]],
                    typ=None, annotation='dynamic unpacker', pos=getpos(stmt_expr))
                pop_return_values.append(o)

    # Order matters: save locals, push args, jump, pop results, restore
    # locals, then leave the returner as the value of the sequence.
    call_body = list(itertools.chain(
        ['seq_unchecked'],
        pre_init,
        push_local_vars,
        push_args,
        jump_to_func,
        pop_return_values,
        pop_local_vars,
        [returner],
    ))
    # If we have no return, we need to pop off
    pop_returner_call_body = ['pop', call_body] if sig.output_type is None else call_body

    o = LLLnode.from_list(
        pop_returner_call_body,
        typ=sig.output_type,
        location='memory',
        pos=getpos(stmt_expr),
        annotation=f'Internal Call: {method_name}',
        add_gas_estimate=sig.gas
    )
    o.gas += sig.gas
    return o