def get_target(self, target):
    """
    Resolve an assignment target to its LLL node, rejecting illegal writes.

    :param target: AST node found on the left-hand side of an assignment.
    :return: LLL node for the target; for a tuple target this is the node
             built from the whole tuple expression.
    :raises StructureException: if the target indexes into a list that is
        currently being iterated, or names a loop iterator that is in use.
    """
    ctx = self.context
    lists_in_iteration = ctx.in_for_loop

    # Writing into a list while a for-loop walks over it is not allowed.
    if isinstance(target, sri_ast.Subscript) and lists_in_iteration:
        container = target.value
        offending_list = None
        if isinstance(container, sri_ast.Attribute):
            qualified = f"{container.value.id}.{container.attr}"
            if qualified in lists_in_iteration:
                offending_list = qualified
        if isinstance(container, sri_ast.Name) and container.id in lists_in_iteration:
            offending_list = container.id
        if offending_list is not None:
            raise StructureException(
                f"Altering list '{offending_list}' which is being iterated!",
                self.stmt,
            )

    # The loop variable itself is read-only for the duration of the loop.
    if isinstance(target, sri_ast.Name) and target.id in ctx.forvars:
        raise StructureException(
            f"Altering iterator '{target.id}' which is in use!",
            self.stmt,
        )

    if not isinstance(target, sri_ast.Tuple):
        location = Expr.parse_variable_location(target, ctx)
        constancy_checks(location, ctx, self.stmt)
        return location

    # Tuple targets: every member must individually pass the constancy check.
    unpacked = Expr(target, ctx).lll_node
    for member in unpacked.args:
        constancy_checks(member, ctx, self.stmt)
    return unpacked
def get_external_contract_keywords(stmt_expr, context):
    """
    Extract the optional ``value`` and ``gas`` keyword arguments of a call.

    :param stmt_expr: call AST node whose ``keywords`` are inspected.
    :param context: context used to parse each keyword's value expression.
    :return: tuple ``(value, gas)``; an entry is ``None`` when not supplied.
    :raises TypeMismatch: if any keyword other than ``value``/``gas`` is used.
    """
    from srilang.parser.expr import Expr

    parsed = {'value': None, 'gas': None}
    for keyword in stmt_expr.keywords:
        name = keyword.arg
        if name not in parsed:
            raise TypeMismatch(
                'Invalid keyword argument, only "gas" and "value" supported.',
                stmt_expr,
            )
        parsed[name] = Expr.parse_value_expr(keyword.value, context)
    return parsed['value'], parsed['gas']
def _assert_reason(self, test_expr, msg):
    """
    Build the LLL for ``assert <test_expr>, <msg>``.

    :param test_expr: already-parsed LLL node of the asserted condition.
    :param msg: AST node of the reason; either a string literal or the name
                ``UNREACHABLE`` (delegated to ``_assert_unreachable``).
    :return: LLL node that stores the ``Error(string)`` selector, the offset
             word and the reason bytes, followed by ``assert_reason``.
    :raises StructureException: if ``msg`` is not a string literal, or the
        literal is empty after stripping whitespace.
    """
    if isinstance(msg, sri_ast.Name) and msg.id == 'UNREACHABLE':
        return self._assert_unreachable(test_expr, msg)
    if not isinstance(msg, sri_ast.Str):
        raise StructureException(
            'Reason parameter of assert needs to be a literal string '
            '(or UNREACHABLE constant).',
            msg,
        )

    reason_str = msg.s.strip()
    if not reason_str:
        raise StructureException('Empty reason string not allowed.', self.stmt)

    ctx = self.context
    # Allocation order matters: selector word, then offset word, then bytes.
    sig_placeholder = ctx.new_placeholder(BaseType(32))
    arg_placeholder = ctx.new_placeholder(BaseType(32))
    reason_str_type = ByteArrayType(len(reason_str))
    placeholder_bytes = Expr(msg, ctx).lll_node

    method_id = fourbytes_to_int(keccak256(b"Error(string)")[:4])
    # The 4-byte selector sits in the last 4 bytes of its 32-byte word.
    revert_start = int(sig_placeholder + 28)
    revert_length = int(4 + get_size_of_type(reason_str_type) * 32)

    return LLLnode.from_list(
        [
            'seq',
            ['mstore', sig_placeholder, method_id],
            ['mstore', arg_placeholder, 32],
            placeholder_bytes,
            ['assert_reason', test_expr, revert_start, revert_length],
        ],
        typ=None,
        pos=getpos(self.stmt),
    )
def _check_valid_range_constant(self, arg_ast_node, raise_exception=True):
    """
    Parse a ``range()`` argument and report whether it is an integer literal.

    :param arg_ast_node: AST node of the range argument.
    :param raise_exception: when true, a non-literal argument raises instead
                            of merely being reported.
    :return: tuple ``(is_integer_literal, arg_expr)``.
    :raises StructureException: if the argument is not a literal ``uint256``
        or ``int128`` and ``raise_exception`` is true.
    """
    ctx = self.context
    with ctx.range_scope():
        # TODO should catch if raise_exception == False?
        arg_expr = Expr.parse_value_expr(arg_ast_node, ctx)

    typ = arg_expr.typ
    is_integer_literal = (
        isinstance(typ, BaseType)
        and typ.is_literal
        and typ.typ in {'uint256', 'int128'}
    )
    if raise_exception and not is_integer_literal:
        raise StructureException(
            "Range only accepts literal (constant) values of type uint256 or int128",
            arg_ast_node,
        )
    return is_integer_literal, arg_expr
def parse_assert(self):
    """
    Compile an ``assert`` statement.

    :return: plain ``['assert', test]`` LLL node when no reason is given,
             otherwise whatever ``_assert_reason`` builds.
    :raises TypeMismatch: if the asserted expression is not boolean.
    """
    stmt = self.stmt
    with self.context.assertion_scope():
        condition = Expr.parse_value_expr(stmt.test, self.context)

    if not self.is_bool_expr(condition):
        raise TypeMismatch('Only boolean expressions allowed', stmt.test)

    if not stmt.msg:
        return LLLnode.from_list(['assert', condition], typ=None, pos=getpos(stmt))
    return self._assert_reason(condition, stmt.msg)
def unroll_constant(self, const, global_ctx):
    """
    Evaluate a constant declaration and coerce it to its annotated type.

    :param const: the constant's annotated-assignment AST node.
    :param global_ctx: global context used to parse the value and annotation.
    :return: ``const`` itself for a byte array whose length is strictly
             below the annotated maximum; otherwise a deep copy of the parsed
             value whose type is replaced by the annotation type.
    :raises TypeMismatch: if the value's type is incompatible with the
        annotation.
    """
    parse_ctx = Context(
        vars=None,
        global_ctx=global_ctx,
        origcode=const.full_source_code,
        memory_allocator=MemoryAllocator(),
    )
    expr = Expr.parse_value_expr(const.value, parse_ctx)
    annotation_type = global_ctx.parse_type(const.annotation.args[0], None)
    type_pair = [expr.typ, annotation_type]

    mismatch = False
    if is_instances(type_pair, ByteArrayType):
        if expr.typ.maxlen < annotation_type.maxlen:
            return const
        mismatch = True
    elif expr.typ != annotation_type:
        mismatch = True
        # special case for literals, which can be uint256 types as well.
        if is_instances(type_pair, BaseType):
            kinds = [annotation_type.typ, expr.typ.typ]
            if kinds == ['uint256', 'int128']:
                mismatch = not SizeLimits.in_bounds('uint256', expr.value)
            elif kinds == ['int128', 'int128']:
                mismatch = not SizeLimits.in_bounds('int128', expr.value)

    if mismatch:
        raise TypeMismatch(
            f"Invalid value for constant type, expected {annotation_type} got "
            f"{expr.typ} instead",
            const.value,
        )

    unrolled = copy.deepcopy(expr)
    unrolled.typ = annotation_type
    # Annotation type doesn't have literal set.
    unrolled.typ.is_literal = expr.typ.is_literal
    return unrolled
def get_constant(self, const_name, context):
    """
    Return the unrolled value of the constant ``const_name``.

    :param const_name: key into ``self._constants``.
    :param context: function context; needed when the stored constant is
                    still an annotated-assignment AST node.
    :return: the stored constant, or the LLL node parsed from its value when
             the stored entry is an ``AnnAssign`` node.
    :raises VariableDeclarationException: if the entry is an ``AnnAssign``
        node and no context is available.
    """
    const = self._constants[const_name]

    # Other types are already unwrapped, no need to parse anything.
    if not isinstance(const, sri_ast.AnnAssign):
        return const

    # Handle ByteArrays: these are kept as AST and parsed on demand.
    if not context:
        raise VariableDeclarationException(
            f"ByteArray: Can not be used outside of a function context: {const_name}"
        )
    return Expr(const.value, context).lll_node
def ann_assign(self):
    """
    Compile an annotated assignment (``x: T = value``) of a new memory variable.

    :return: LLL node that stores the value into the new variable.
    :raises TypeMismatch: if the target is an attribute (a field).
    :raises StructureException: if no initial value is given.
    """
    stmt = self.stmt
    ctx = self.context
    with ctx.assignment_scope():
        typ = parse_type(
            stmt.annotation,
            location='memory',
            custom_structs=ctx.structs,
            constants=ctx.constants,
        )
        if isinstance(stmt.target, sri_ast.Attribute):
            raise TypeMismatch(
                f'May not set type for field {stmt.target.attr}',
                stmt,
            )

        # The variable is registered before the initialiser is inspected.
        var_pos = ctx.new_variable(stmt.target.id, typ)
        if stmt.value is None:
            raise StructureException(
                'New variables must be initialized explicitly', stmt)

        rhs = Expr(stmt.value, ctx).lll_node
        stmt_pos = getpos(stmt)

        # If bytes[32] to bytes32 assignment rewrite the value as bytes32.
        wants_bytes32 = isinstance(typ, BaseType) and typ.typ == 'bytes32'
        if (
            wants_bytes32
            and isinstance(rhs.typ, ByteArrayType)
            and rhs.typ.maxlen == 32
            and rhs.typ.is_literal
        ):
            rhs = LLLnode(
                bytes_to_int(stmt.value.s),
                typ=BaseType('bytes32'),
                pos=stmt_pos,
            )

        self._check_valid_assign(rhs)
        self._check_same_variable_assign(rhs)

        destination = LLLnode.from_list(
            var_pos,
            typ=typ,
            location='memory',
            pos=stmt_pos,
        )
        return make_setter(destination, rhs, 'memory', pos=getpos(stmt))
def parse_if(self):
    """
    Compile an ``if``/``else`` statement.

    The ``else`` body (when present) is parsed first, in its own block scope,
    and only then the test and the main body.

    :return: LLL ``if`` node with the optional else branch appended.
    :raises TypeMismatch: if the test is not a boolean expression.
    """
    stmt = self.stmt
    ctx = self.context

    else_branch = []
    if stmt.orelse:
        with ctx.make_blockscope(id(stmt.orelse)):
            else_branch.append(parse_body(stmt.orelse, ctx))

    with ctx.make_blockscope(id(stmt)):
        condition = Expr.parse_value_expr(stmt.test, ctx)
        if not self.is_bool_expr(condition):
            raise TypeMismatch('Only boolean expressions allowed', stmt.test)
        then_branch = parse_body(stmt.body, ctx)
        return LLLnode.from_list(
            ['if', condition, then_branch] + else_branch,
            typ=None,
            pos=getpos(stmt),
        )
def assign(self):
    """
    Compile a plain assignment (e.g. ``x[4] = y``).

    :return: LLL setter node positioned at the statement.
    :raises VariableDeclarationException: if a bare name target is undeclared,
        or both sides of the assignment are tuples.
    :raises TypeMismatch: if a contract-typed target receives a value that is
        not contract-typed.
    """
    stmt = self.stmt
    lhs = stmt.target
    with self.context.assignment_scope():
        rhs = Expr(stmt.value, self.context).lll_node

        if isinstance(lhs, sri_ast.Name):
            # Do not allow assignment to undefined variables without annotation
            if lhs.id not in self.context.vars:
                raise VariableDeclarationException(
                    "Variable type not defined", stmt)
            # Check against implicit conversion
            self._check_implicit_conversion(lhs.id, rhs)

        # Do not allow tuple-to-tuple assignment
        if isinstance(lhs, sri_ast.Tuple) and isinstance(stmt.value, sri_ast.Tuple):
            raise VariableDeclarationException(
                "Tuple to tuple assignment not supported",
                stmt,
            )

        # Checks to see if assignment is valid
        target = self.get_target(lhs)
        rhs_is_contract = isinstance(rhs.typ, ContractType)
        if isinstance(target.typ, ContractType) and not rhs_is_contract:
            raise TypeMismatch(
                'Contract assignment expects casted address: '
                f'{target.typ}(<address_var>)',
                stmt,
            )

        setter = make_setter(target, rhs, target.location, pos=getpos(stmt))
        setter.pos = getpos(stmt)
        return setter
def call_lookup_specs(stmt_expr, context):
    """
    Resolve the signature of a call made through ``self``.

    :param stmt_expr: call AST node; ``func.attr`` is the method name.
    :param context: context used to parse arguments and look up signatures.
    :return: tuple ``(method_name, expr_args, sig)``.
    :raises TypeMismatch: if the call uses keyword arguments.
    """
    from srilang.parser.expr import Expr

    method_name = stmt_expr.func.attr
    if len(stmt_expr.keywords) > 0:
        raise TypeMismatch(
            "Cannot use keyword arguments in calls to functions via 'self'",
            stmt_expr,
        )

    expr_args = []
    for raw_arg in stmt_expr.args:
        expr_args.append(Expr(raw_arg, context).lll_node)

    sig = FunctionSignature.lookup_sig(
        context.sigs,
        method_name,
        expr_args,
        stmt_expr,
        context,
    )
    return method_name, expr_args, sig
def pack_logging_topics(event_id, args, expected_topics, context, pos):
    """
    Build the list of topics for a log statement.

    :param event_id: value used as the first topic.
    :param args: AST nodes of the supplied arguments; ``args[i]`` is matched
                 with ``expected_topics[i]``.
    :param expected_topics: declared topics; each exposes its type as ``.typ``.
    :param context: context used to parse each argument.
    :param pos: source position reported in most error messages.
    :return: list starting with ``event_id`` followed by one entry per topic.
    :raises TypeMismatch: if a byte-array argument exceeds the declared
        maximum length, a non-byte-array argument's type differs from the
        declared one, or a non-literal byte-array topic lives in a location
        other than memory or storage.
    :raises InvalidLiteral: if a string literal topic is longer than 32 bytes.
    """
    topics = [event_id]

    # A dedicated index name keeps the ``pos`` parameter intact for errors.
    for idx, expected_topic in enumerate(expected_topics):
        expected_type = expected_topic.typ
        arg = args[idx]
        value = Expr(arg, context).lll_node
        arg_type = value.typ

        both_byte_arrays = (
            isinstance(arg_type, ByteArrayLike)
            and isinstance(expected_type, ByteArrayLike)
        )
        if not both_byte_arrays:
            if arg_type != expected_type:
                raise TypeMismatch(
                    f"Invalid type for logging topic, got {arg_type} expected {expected_type}",
                    value.pos
                )
            value = unwrap_location(value)
            value = base_type_conversion(value, arg_type, expected_type, pos=pos)
            topics.append(value)
            continue

        if arg_type.maxlen > expected_type.maxlen:
            raise TypeMismatch(
                f"Topic input bytes are too big: {arg_type} {expected_type}", pos
            )

        if isinstance(arg, sri_ast.Str):
            # String literals are folded into a right-padded 32-byte integer.
            bytez, bytez_length = string_to_bytes(arg.s)
            if len(bytez) > 32:
                raise InvalidLiteral(
                    "Can only log a maximum of 32 bytes at a time.", pos
                )
            topics.append(bytes_to_int(bytez + b'\x00' * (32 - bytez_length)))
            continue

        # Non-literal byte arrays: the length word is read from where the
        # value lives.
        if value.location == "memory":
            size = ['mload', value]
        elif value.location == "storage":
            size = ['sload', ['sha3_32', value]]
        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise TypeMismatch(
                f"Cannot log byte array topic from location '{value.location}'",
                pos
            )
        topics.append(byte_array_to_num(value, arg, 'uint256', size))

    return topics
def external_contract_call(node, context, contract_name, contract_address, pos,
                           value=None, gas=None):
    """
    Build the LLL for a call into another contract.

    :param node: call AST node; ``node.func.attr`` is the method name and
                 ``node.args`` are the call arguments.
    :param context: current function context (signatures, constancy).
    :param contract_name: key into ``context.sigs`` naming the callee interface.
    :param contract_address: LLL node holding the callee address.
    :param pos: accepted for interface compatibility; not read by this function.
    :param value: value forwarded with a ``call``; defaults to ``0``.
    :param gas: gas forwarded with the call; defaults to the ``'gas'`` opcode.
    :return: LLL node typed with the callee's output type, located in memory.
    :raises StructureException: if ``contract_name`` is empty or the call
        targets the contract's own address.
    :raises VariableDeclarationException: if the contract is not declared.
    :raises FunctionDeclarationException: if the method is not declared on
        the contract.
    :raises ConstancyViolation: if a non-constant function is called from a
        constant context.
    """
    from srilang.parser.expr import (
        Expr,
    )

    if value is None:
        value = 0
    if gas is None:
        gas = 'gas'

    if not contract_name:
        raise StructureException(
            f'Invalid external contract call "{node.func.attr}".', node)
    if contract_name not in context.sigs:
        raise VariableDeclarationException(
            f'Contract "{contract_name}" not declared yet', node)
    if contract_address.value == "address":
        raise StructureException("External calls to self are not permitted.", node)

    method_name = node.func.attr
    contract_sigs = context.sigs[contract_name]
    if method_name not in contract_sigs:
        # The two sentences used to run together without a separator.
        raise FunctionDeclarationException((
            f"Function not declared yet: {method_name} (reminder: "
            "function must be declared in the correct contract) "
            f"The available methods are: {','.join(contract_sigs.keys())}"
        ), node.func)

    sig = contract_sigs[method_name]
    inargs, inargsize, _ = pack_arguments(
        sig,
        [Expr(arg, context).lll_node for arg in node.args],
        context,
        node.func,
    )
    output_placeholder, output_size, returner = get_external_contract_call_output(
        sig, context)

    # Guard: the callee must have code and must not be this contract.
    sub = [
        'seq',
        ['assert', ['extcodesize', contract_address]],
        ['assert', ['ne', 'address', contract_address]],
    ]

    if context.is_constant() and not sig.const:
        raise ConstancyViolation(
            f"May not call non-constant function '{method_name}' within {context.pp_constancy()}."
            " For asserting the result of modifiable contract calls, try assert_modifiable.",
            node)

    if context.is_constant() or sig.const:
        # Constant callee or constant caller: staticcall takes no value.
        sub.append([
            'assert',
            [
                'staticcall',
                gas, contract_address, inargs, inargsize, output_placeholder, output_size,
            ]
        ])
    else:
        sub.append([
            'assert',
            [
                'call',
                gas, contract_address, value, inargs, inargsize, output_placeholder, output_size,
            ]
        ])
    sub.extend(returner)

    return LLLnode.from_list(sub, typ=sig.output_type, location='memory', pos=getpos(node))
def make_external_call(stmt_expr, context):
    """
    Work out the callee contract and address of an external call, then
    delegate to ``external_contract_call``.

    Three call shapes are recognised, tried in this order:

    1. ``Contract(<address>).method(...)`` — the receiver is itself a call.
    2. ``self.<attr>.method(...)`` where ``<attr>`` is a key of ``context.sigs``.
    3. ``self.<attr>.method(...)`` where ``<attr>`` is a global whose type
       has a ``name`` attribute.

    :param stmt_expr: call AST node.
    :param context: current function context.
    :return: LLL node produced by ``external_contract_call``.
    :raises StructureException: if the call matches none of the shapes.
    """
    from srilang.parser.expr import Expr

    value, gas = get_external_contract_keywords(stmt_expr, context)
    func = stmt_expr.func

    if isinstance(func, sri_ast.Attribute) and isinstance(func.value, sri_ast.Call):
        contract_name = func.value.func.id
        contract_address = Expr.parse_value_expr(func.value.args[0], context)
    else:
        receiver = func.value
        if not isinstance(receiver, sri_ast.Attribute):
            raise StructureException("Unsupported operator.", stmt_expr)

        attr = receiver.attr
        if attr in context.sigs:
            contract_name = attr
            var = context.globals[attr]
        elif attr in context.globals and hasattr(context.globals[attr].typ, 'name'):
            var = context.globals[attr]
            contract_name = var.typ.name
        else:
            raise StructureException("Unsupported operator.", stmt_expr)

        # The address is read out of the storage variable.
        contract_address = unwrap_location(
            LLLnode.from_list(
                var.pos,
                typ=var.typ,
                location='storage',
                pos=getpos(stmt_expr),
                annotation='self.' + attr,
            ))

    return external_contract_call(
        stmt_expr,
        context,
        contract_name,
        contract_address,
        pos=getpos(stmt_expr),
        value=value,
        gas=gas,
    )
def pack_args_by_32(holder, maxlen, arg, typ, context, placeholder,
                    dynamic_offset_counter=None, datamem_start=None, pos=None):
    """
    Copy necessary variables to pre-allocated memory section.

    Appends the LLL needed to write ``arg`` (of declared type ``typ``) at
    ``placeholder`` to ``holder``, recursing element by element for lists.

    :param holder: Complete holder for all args; LLL statements are appended
                   to it in place.
    :param maxlen: Total length in bytes of the full arg section
                   (static + dynamic); grown here only for list types.
    :param arg: Current arg to pack; either an AST node or an already built
                LLLnode.
    :param typ: Declared type of the arg.
    :param context: Context of arg
    :param placeholder: Static placeholder for static argument part.
    :param dynamic_offset_counter: position counter stored in static args
                                   (read and advanced for byte arrays).
    :param datamem_start: position where the whole datamem section starts.
    :param pos: source position forwarded to conversion/copy helpers.
    :return: tuple ``(holder, maxlen)``.
    :raises TypeMismatch: if a storage or variable list's element type does
        not match the declared list element type.
    """

    if isinstance(typ, BaseType):
        if isinstance(arg, LLLnode):
            value = unwrap_location(arg)
        else:
            value = Expr(arg, context).lll_node
            value = base_type_conversion(value, value.typ, typ, pos)
        holder.append(LLLnode.from_list(['mstore', placeholder, value], typ=typ, location='memory'))
    elif isinstance(typ, ByteArrayLike):

        if isinstance(arg, LLLnode):  # Is preallocated variable.
            source_lll = arg
        else:
            source_lll = Expr(arg, context).lll_node

        # Set static offset, in arg slot.
        holder.append(LLLnode.from_list(['mstore', placeholder, ['mload', dynamic_offset_counter]]))
        # Get the beginning to write the ByteArray to.
        dest_placeholder = LLLnode.from_list(
            ['add', datamem_start, ['mload', dynamic_offset_counter]],
            typ=typ, location='memory', annotation="pack_args_by_32:dest_placeholder")
        copier = make_byte_array_copier(dest_placeholder, source_lll, pos=pos)
        holder.append(copier)
        # Add zero padding.
        holder.append(zero_pad(dest_placeholder))

        # Increment offset counter: previous offset + padded data length + 32
        # (the extra word accounts for the stored length).
        increment_counter = LLLnode.from_list([
            'mstore', dynamic_offset_counter,
            [
                'add',
                ['add', ['mload', dynamic_offset_counter], ['ceil32', ['mload', dest_placeholder]]],
                32,
            ],
        ], annotation='Increment dynamic offset counter')
        holder.append(increment_counter)
    elif isinstance(typ, ListType):
        # One slot was already counted for this arg; add the remaining ones.
        maxlen += (typ.count - 1) * 32
        typ = typ.subtype

        def check_list_type_match(provided):  # Check list types match.
            if provided != typ:
                raise TypeMismatch(
                    f"Log list type '{provided}' does not match provided, expected '{typ}'"
                )

        # NOTE: Below code could be refactored into iterators/getter functions for each type of
        #       repetitive loop. But seeing how each one is a unique for loop, and in which way
        #       the sub value makes the difference in each type of list clearer.

        # List from storage
        if isinstance(arg, sri_ast.Attribute) and arg.value.id == 'self':
            stor_list = context.globals[arg.attr]
            check_list_type_match(stor_list.typ.subtype)
            size = stor_list.typ.count
            mem_offset = 0
            for i in range(0, size):
                storage_offset = i
                arg2 = LLLnode.from_list(
                    ['sload', ['add', ['sha3_32', Expr(arg, context).lll_node], storage_offset]],
                    typ=typ,
                )
                holder, maxlen = pack_args_by_32(
                    holder,
                    maxlen,
                    arg2,
                    typ,
                    context,
                    placeholder + mem_offset,
                    pos=pos,
                )
                mem_offset += get_size_of_type(typ) * 32

        # List from variable.
        elif isinstance(arg, sri_ast.Name):
            size = context.vars[arg.id].size
            # NOTE(review): this rebinds the ``pos`` parameter (source position)
            # to the variable's memory position, and that value is then also
            # forwarded as ``pos=`` to the recursive call below — confirm this
            # is intended.
            pos = context.vars[arg.id].pos
            check_list_type_match(context.vars[arg.id].typ.subtype)
            mem_offset = 0
            for _ in range(0, size):
                arg2 = LLLnode.from_list(
                    pos + mem_offset,
                    typ=typ,
                    location=context.vars[arg.id].location
                )
                holder, maxlen = pack_args_by_32(
                    holder,
                    maxlen,
                    arg2,
                    typ,
                    context,
                    placeholder + mem_offset,
                    pos=pos,
                )
                mem_offset += get_size_of_type(typ) * 32

        # List from list literal.
        else:
            mem_offset = 0
            for arg2 in arg.elts:
                holder, maxlen = pack_args_by_32(
                    holder,
                    maxlen,
                    arg2,
                    typ,
                    context,
                    placeholder + mem_offset,
                    pos=pos,
                )
                mem_offset += get_size_of_type(typ) * 32
    return holder, maxlen
def pack_logging_data(expected_data, args, context, pos):
    """
    Build the LLL that packs the non-topic (data) arguments of a log statement.

    :param expected_data: declared data arguments; each exposes its type as
                          ``.typ``.
    :param args: AST nodes of the supplied data arguments, matched with
                 ``expected_data`` by position.
    :param context: context used for parsing and for allocating placeholders.
    :param pos: source position used for error reporting.
    :return: tuple ``(holder, maxlen, dynamic_offset_counter, datamem_start)``;
             ``(['seq'], 0, None, 0)`` when there are no args.
    :raises TypeMismatch: if a string literal is longer than the maximum
        length of its parsed type.
    """
    # Checks to see if there's any data
    if not args:
        return ['seq'], 0, None, 0
    holder = ['seq']
    maxlen = len(args) * 32  # total size of all packed args (upper limit)

    # Unroll any function calls, to temp variables.
    # Maps argument index -> LLL node of the temporary holding its value.
    prealloacted = {}
    for idx, (arg, _expected_arg) in enumerate(zip(args, expected_data)):

        if isinstance(arg, (sri_ast.Str, sri_ast.Call)):
            expr = Expr(arg, context)
            source_lll = expr.lll_node
            typ = source_lll.typ

            if isinstance(arg, sri_ast.Str):
                if len(arg.s) > typ.maxlen:
                    raise TypeMismatch(
                        f"Data input bytes are to big: {len(arg.s)} {typ}", pos
                    )

            # The temporary's name is derived from the arg's source location.
            tmp_variable = context.new_internal_variable(
                f'_log_pack_var_{arg.lineno}_{arg.col_offset}',
                source_lll.typ,
            )
            tmp_variable_node = LLLnode.from_list(
                tmp_variable,
                typ=source_lll.typ,
                pos=getpos(arg),
                location="memory",
                annotation=f'log_prealloacted {source_lll.typ}',
            )
            # Store len.
            # holder.append(['mstore', len_placeholder, ['mload', unwrap_location(source_lll)]])
            # Copy bytes.

            holder.append(
                make_setter(tmp_variable_node, source_lll, pos=getpos(arg), location='memory')
            )
            prealloacted[idx] = tmp_variable_node

    requires_dynamic_offset = any([isinstance(data.typ, ByteArrayLike) for data in expected_data])
    if requires_dynamic_offset:
        dynamic_offset_counter = context.new_placeholder(BaseType(32))
        # Only defined (and only read below) when a byte-array arg exists.
        dynamic_placeholder = context.new_placeholder(BaseType(32))
    else:
        dynamic_offset_counter = None

    # Create placeholder for static args. Note: order of new_*() is important.
    placeholder_map = {}
    for i, (_arg, data) in enumerate(zip(args, expected_data)):
        typ = data.typ
        if not isinstance(typ, ByteArrayLike):
            placeholder = context.new_placeholder(typ)
        else:
            # Byte arrays get a single 32-byte slot in the static part.
            placeholder = context.new_placeholder(BaseType(32))
        placeholder_map[i] = placeholder

    # Populate static placeholders.
    for i, (arg, data) in enumerate(zip(args, expected_data)):
        typ = data.typ
        placeholder = placeholder_map[i]
        if not isinstance(typ, ByteArrayLike):
            holder, maxlen = pack_args_by_32(
                holder,
                maxlen,
                prealloacted.get(i, arg),
                typ,
                context,
                placeholder,
                pos=pos,
            )

    # Dynamic position starts right after the static args.
    if requires_dynamic_offset:
        holder.append(LLLnode.from_list(['mstore', dynamic_offset_counter, maxlen]))

    # Calculate maximum dynamic offset placeholders, used for gas estimation.
    for _arg, data in zip(args, expected_data):
        typ = data.typ
        if isinstance(typ, ByteArrayLike):
            maxlen += 32 + ceil32(typ.maxlen)

    if requires_dynamic_offset:
        datamem_start = dynamic_placeholder + 32
    else:
        datamem_start = placeholder_map[0]

    # Copy necessary data into allocated dynamic section.
    for i, (arg, data) in enumerate(zip(args, expected_data)):
        typ = data.typ
        if isinstance(typ, ByteArrayLike):
            # The return value is not used here; ``holder`` is appended to in
            # place by pack_args_by_32.
            pack_args_by_32(
                holder=holder,
                maxlen=maxlen,
                arg=prealloacted.get(i, arg),
                typ=typ,
                context=context,
                placeholder=placeholder_map[i],
                datamem_start=datamem_start,
                dynamic_offset_counter=dynamic_offset_counter,
                pos=pos
            )

    return holder, maxlen, dynamic_offset_counter, datamem_start
def parse_for(self):
    """
    Compile a ``for`` loop into an LLL ``repeat`` node.

    List iteration is delegated to ``parse_for_list``; otherwise the iterable
    must be a call to ``range`` with one or two arguments.

    :return: LLL node for the loop.
    :raises StructureException: if the iterable is not a supported form, the
        ``range`` arguments are malformed, or the iteration count is below 1.
    """
    # Type 0 for, e.g. for i in list(): ...
    if self._is_list_iter():
        return self.parse_for_list()

    if not isinstance(self.stmt.iter, sri_ast.Call):
        if isinstance(self.stmt.iter, sri_ast.Subscript):
            raise StructureException("Cannot iterate over a nested list", self.stmt.iter)
        raise StructureException(
            f"Cannot iterate over '{type(self.stmt.iter).__name__}' object", self.stmt.iter)
    if getattr(self.stmt.iter.func, 'id', None) != "range":
        raise StructureException(
            "Non-literals cannot be used as loop range", self.stmt.iter.func)
    if len(self.stmt.iter.args) not in {1, 2}:
        raise StructureException(
            f"Range expects between 1 and 2 arguments, got {len(self.stmt.iter.args)}",
            self.stmt.iter.func)

    block_scope_id = id(self.stmt)
    with self.context.make_blockscope(block_scope_id):
        # Get arg0
        arg0 = self.stmt.iter.args[0]
        num_of_args = len(self.stmt.iter.args)

        # Type 1 for, e.g. for i in range(10): ...
        if num_of_args == 1:
            arg0_val = self._get_range_const_value(arg0)
            start = LLLnode.from_list(0, typ='int128', pos=getpos(self.stmt))
            rounds = arg0_val

        # Type 2 for, e.g. for i in range(100, 110): ...
        elif self._check_valid_range_constant(self.stmt.iter.args[1],
                                              raise_exception=False)[0]:
            arg0_val = self._get_range_const_value(arg0)
            arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
            start = LLLnode.from_list(arg0_val, typ='int128', pos=getpos(self.stmt))
            rounds = LLLnode.from_list(arg1_val - arg0_val, typ='int128', pos=getpos(self.stmt))

        # Type 3 for, e.g. for i in range(x, x + 10): ...
        else:
            arg1 = self.stmt.iter.args[1]
            if not isinstance(arg1, sri_ast.BinOp) or not isinstance(
                    arg1.op, sri_ast.Add):
                raise StructureException(
                    ("Two-arg for statements must be of the form `for i "
                     "in range(start, start + rounds): ...`"),
                    arg1,
                )

            # The start expression must appear verbatim as the left operand.
            if not sri_ast.compare_nodes(arg0, arg1.left):
                raise StructureException(
                    ("Two-arg for statements of the form `for i in "
                     "range(x, x + y): ...` must have x identical in both "
                     f"places: {sri_ast.ast_to_dict(arg0)} {sri_ast.ast_to_dict(arg1.left)}"
                     ),
                    self.stmt.iter,
                )

            rounds = self._get_range_const_value(arg1.right)
            start = Expr.parse_value_expr(arg0, self.context)

        # ``rounds`` is either a plain int or a node carrying it in ``.value``.
        r = rounds if isinstance(rounds, int) else rounds.value
        if r < 1:
            raise StructureException(
                f"For loop has invalid number of iterations ({r}),"
                " the value must be greater than zero", self.stmt.iter)

        varname = self.stmt.target.id
        pos = self.context.new_variable(varname, BaseType('int128'), pos=getpos(self.stmt))
        # Mark the iterator as in use while the body is parsed.
        self.context.forvars[varname] = True
        o = LLLnode.from_list(
            [
                'repeat', pos, start, rounds,
                parse_body(self.stmt.body, self.context)
            ],
            typ=None,
            pos=getpos(self.stmt),
        )
        # The loop variable does not outlive the loop.
        del self.context.vars[varname]
        del self.context.forvars[varname]

    return o
def parse_return(self):
    """
    Compile a ``return`` statement into LLL.

    Dispatches on the type of the returned expression: base types, byte
    arrays, lists, structs and tuples each have their own handling.

    :return: LLL node performing the return.
    :raises TypeMismatch: if a value is returned where none is expected (or
        the reverse), or the returned type is incompatible with the
        function's declared return type.
    :raises InvalidLiteral: if a literal is out of bounds for the return type.
    :raises StructureException: if a returned tuple's length or member types
        do not match the declared return type.
    :raises Exception: if a returned byte array is in neither storage nor
        memory.
    """
    if self.context.return_type is None:
        if self.stmt.value:
            raise TypeMismatch("Not expecting to return a value", self.stmt)
        return LLLnode.from_list(
            make_return_stmt(self.stmt, self.context, 0, 0),
            typ=None,
            pos=getpos(self.stmt),
            valency=0,
        )
    if not self.stmt.value:
        raise TypeMismatch("Expecting to return a value", self.stmt)

    sub = Expr(self.stmt.value, self.context).lll_node

    # Returning a value (most common case)
    if isinstance(sub.typ, BaseType):
        sub = unwrap_location(sub)

        if self.context.return_type != sub.typ and not sub.typ.is_literal:
            raise TypeMismatch(
                f"Trying to return base type {sub.typ}, output expecting "
                f"{self.context.return_type}",
                self.stmt.value,
            )
        elif sub.typ.is_literal and (
                self.context.return_type.typ == sub.typ or
                'int' in self.context.return_type.typ and
                'int' in sub.typ.typ):  # noqa: E501
            # Literals must fit the declared return type.
            if not SizeLimits.in_bounds(self.context.return_type.typ, sub.value):
                raise InvalidLiteral(
                    "Number out of range: " + str(sub.value), self.stmt)
            else:
                return LLLnode.from_list(
                    [
                        'seq',
                        ['mstore', 0, sub],
                        make_return_stmt(self.stmt, self.context, 0, 32)
                    ],
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
        elif is_base_type(sub.typ, self.context.return_type.typ) or (
                is_base_type(sub.typ, 'int128') and is_base_type(
                    self.context.return_type, 'int256')):  # noqa: E501
            return LLLnode.from_list(
                [
                    'seq',
                    ['mstore', 0, sub],
                    make_return_stmt(self.stmt, self.context, 0, 32)
                ],
                typ=None,
                pos=getpos(self.stmt),
                valency=0,
            )
        else:
            raise TypeMismatch(
                f"Unsupported type conversion: {sub.typ} to {self.context.return_type}",
                self.stmt.value,
            )
    # Returning a byte array
    elif isinstance(sub.typ, ByteArrayLike):
        if not sub.typ.eq_base(self.context.return_type):
            raise TypeMismatch(
                f"Trying to return base type {sub.typ}, output expecting "
                f"{self.context.return_type}",
                self.stmt.value,
            )
        if sub.typ.maxlen > self.context.return_type.maxlen:
            raise TypeMismatch(
                f"Cannot cast from greater max-length {sub.typ.maxlen} to shorter "
                f"max-length {self.context.return_type.maxlen}",
                self.stmt.value,
            )

        # loop memory has to be allocated first.
        loop_memory_position = self.context.new_placeholder(
            typ=BaseType('uint256'))
        # len & bytez placeholder have to be declared after each other at all times.
        len_placeholder = self.context.new_placeholder(
            typ=BaseType('uint256'))
        bytez_placeholder = self.context.new_placeholder(typ=sub.typ)

        if sub.location in ('storage', 'memory'):
            return LLLnode.from_list([
                'seq',
                make_byte_array_copier(LLLnode(
                    bytez_placeholder, location='memory', typ=sub.typ),
                    sub,
                    pos=getpos(self.stmt)),
                zero_pad(bytez_placeholder),
                ['mstore', len_placeholder, 32],
                make_return_stmt(
                    self.stmt,
                    self.context,
                    len_placeholder,
                    ['ceil32', ['add', ['mload', bytez_placeholder], 64]],
                    loop_memory_position=loop_memory_position,
                )
            ], typ=None, pos=getpos(self.stmt), valency=0)
        else:
            raise Exception(f"Invalid location: {sub.location}")

    elif isinstance(sub.typ, ListType):
        loop_memory_position = self.context.new_placeholder(
            typ=BaseType('uint256'))
        if sub.typ != self.context.return_type:
            raise TypeMismatch(
                f"List return type {sub.typ} does not match specified "
                f"return type, expecting {self.context.return_type}",
                self.stmt)
        elif sub.location == "memory" and sub.value != "multi":
            # Already laid out in memory: return it directly.
            return LLLnode.from_list(
                make_return_stmt(
                    self.stmt,
                    self.context,
                    sub,
                    get_size_of_type(self.context.return_type) * 32,
                    loop_memory_position=loop_memory_position,
                ),
                typ=None,
                pos=getpos(self.stmt),
                valency=0,
            )
        else:
            # Otherwise copy into a fresh memory placeholder first.
            new_sub = LLLnode.from_list(
                self.context.new_placeholder(self.context.return_type),
                typ=self.context.return_type,
                location='memory',
            )
            setter = make_setter(new_sub, sub, 'memory', pos=getpos(self.stmt))
            return LLLnode.from_list([
                'seq',
                setter,
                make_return_stmt(
                    self.stmt,
                    self.context,
                    new_sub,
                    get_size_of_type(self.context.return_type) * 32,
                    loop_memory_position=loop_memory_position,
                )
            ], typ=None, pos=getpos(self.stmt))

    # Returning a struct
    elif isinstance(sub.typ, StructType):
        retty = self.context.return_type
        if not isinstance(retty, StructType) or retty.name != sub.typ.name:
            raise TypeMismatch(
                f"Trying to return {sub.typ}, output expecting {self.context.return_type}",
                self.stmt.value,
            )
        return gen_tuple_return(self.stmt, self.context, sub)

    # Returning a tuple.
    elif isinstance(sub.typ, TupleType):
        if not isinstance(self.context.return_type, TupleType):
            raise TypeMismatch(
                f"Trying to return tuple type {sub.typ}, output expecting "
                f"{self.context.return_type}",
                self.stmt.value,
            )

        if len(self.context.return_type.members) != len(sub.typ.members):
            raise StructureException("Tuple lengths don't match!", self.stmt)

        # check return type matches, sub type.
        # Only the Python class of each member type is compared here.
        for i, ret_x in enumerate(self.context.return_type.members):
            s_member = sub.typ.members[i]
            sub_type = s_member if isinstance(s_member, NodeType) else s_member.typ
            if type(sub_type) is not type(ret_x):
                raise StructureException(
                    "Tuple return type does not match annotated return. "
                    f"{type(sub_type)} != {type(ret_x)}",
                    self.stmt)
        return gen_tuple_return(self.stmt, self.context, sub)

    else:
        raise TypeMismatch(f"Can't return type {sub.typ}", self.stmt)
def aug_assign(self):
    """
    Compile an augmented assignment (``x += y`` and friends).

    The target address is bound once with ``with`` and then both read and
    written through that binding.

    :return: LLL node for storage or memory targets; ``None`` for a target in
             any other location.
    :raises StructureException: if the operator is not one of
        ``+ - * / %``.
    :raises TypeMismatch: if the target is not a base type.
    """
    stmt = self.stmt
    target = self.get_target(stmt.target)
    sub = Expr.parse_value_expr(stmt.value, self.context)

    supported_ops = (sri_ast.Add, sri_ast.Sub, sri_ast.Mult, sri_ast.Div, sri_ast.Mod)
    if not isinstance(stmt.op, supported_ops):
        raise StructureException("Unsupported operator for augassign", stmt)
    if not isinstance(target.typ, BaseType):
        raise TypeMismatch(
            "Can only use aug-assign operators with simple types!",
            stmt.target)

    # Pick the load/store opcodes and the alias bound to the target address.
    if target.location == 'storage':
        load_op, store_op, loc_alias = 'sload', 'sstore', '_stloc'
    elif target.location == 'memory':
        load_op, store_op, loc_alias = 'mload', 'mstore', '_mloc'
    else:
        return None

    current_value = LLLnode.from_list(
        [load_op, loc_alias], typ=target.typ, pos=target.pos)
    combined = Expr.parse_value_expr(
        sri_ast.BinOp(
            left=current_value,
            right=sub,
            op=stmt.op,
            lineno=stmt.lineno,
            col_offset=stmt.col_offset,
            end_lineno=stmt.end_lineno,
            end_col_offset=stmt.end_col_offset,
        ),
        self.context,
    )
    converted = base_type_conversion(
        combined, combined.typ, target.typ, pos=getpos(stmt))
    return LLLnode.from_list(
        ['with', loc_alias, target, [store_op, loc_alias, converted]],
        typ=None,
        pos=getpos(stmt),
    )
def parse_for_list(self):
    """
    Compile a ``for`` loop that iterates over a list.

    Handles three sources for the list: a named variable, a list literal
    written in the ``for`` statement, and an attribute (storage list).

    :return: LLL node for the loop.
    :raises StructureException: if the list's element type is not a base type.
    :raises CompilerPanic: if a named list variable is in neither calldata
        nor memory.
    """
    with self.context.range_scope():
        iter_list_node = Expr(self.stmt.iter, self.context).lll_node
    if not isinstance(iter_list_node.typ.subtype, BaseType):  # Sanity check on list subtype.
        raise StructureException(
            'For loops allowed only on basetype lists.', self.stmt.iter)
    # NOTE(review): ``.get()`` returns None for an unknown name, which would
    # make the ``.typ`` access fail with AttributeError — presumably the
    # Expr() parse above already rejects unknown names; confirm.
    iter_var_type = (self.context.vars.get(self.stmt.iter.id).typ
                     if isinstance(self.stmt.iter, sri_ast.Name) else None)
    subtype = iter_list_node.typ.subtype.typ
    varname = self.stmt.target.id
    value_pos = self.context.new_variable(
        varname,
        BaseType(subtype),
    )
    # Internal loop counter; note it is declared with the element type.
    i_pos_raw_name = '_index_for_' + varname
    i_pos = self.context.new_internal_variable(
        i_pos_raw_name,
        BaseType(subtype),
    )
    self.context.forvars[varname] = True

    # Is a list that is already allocated to memory.
    if iter_var_type:
        list_name = self.stmt.iter.id
        # make sure list cannot be altered whilst iterating.
        with self.context.in_for_loop_scope(list_name):
            iter_var = self.context.vars.get(self.stmt.iter.id)
            if iter_var.location == 'calldata':
                fetcher = 'calldataload'
            elif iter_var.location == 'memory':
                fetcher = 'mload'
            else:
                # NOTE(review): ``self.expr`` is not assigned anywhere in the
                # visible code (other methods use ``self.stmt``); if it does
                # not exist this raises AttributeError instead — confirm.
                raise CompilerPanic(
                    f'List iteration only supported on in-memory types {self.expr}',
                )
            body = [
                'seq',
                [
                    'mstore',
                    value_pos,
                    [
                        fetcher,
                        [
                            'add', iter_var.pos,
                            ['mul', ['mload', i_pos], 32]
                        ]
                    ],
                ],
                parse_body(self.stmt.body, self.context)
            ]
            o = LLLnode.from_list(
                ['repeat', i_pos, 0, iter_var.size, body],
                typ=None,
                pos=getpos(self.stmt))

    # List gets defined in the for statement.
    elif isinstance(self.stmt.iter, sri_ast.List):
        # Allocate list to memory.
        count = iter_list_node.typ.count
        tmp_list = LLLnode.from_list(obj=self.context.new_placeholder(
            ListType(iter_list_node.typ.subtype, count)),
            typ=ListType(
                iter_list_node.typ.subtype, count),
            location='memory')
        setter = make_setter(tmp_list,
                             iter_list_node,
                             'memory',
                             pos=getpos(self.stmt))
        body = [
            'seq',
            [
                'mstore', value_pos,
                [
                    'mload',
                    ['add', tmp_list, ['mul', ['mload', i_pos], 32]]
                ]
            ],
            parse_body(self.stmt.body, self.context)
        ]
        o = LLLnode.from_list(
            ['seq', setter, ['repeat', i_pos, 0, count, body]],
            typ=None,
            pos=getpos(self.stmt))

    # List contained in storage.
    elif isinstance(self.stmt.iter, sri_ast.Attribute):
        count = iter_list_node.typ.count
        list_name = iter_list_node.annotation

        # make sure list cannot be altered whilst iterating.
        with self.context.in_for_loop_scope(list_name):
            body = [
                'seq',
                [
                    'mstore', value_pos,
                    [
                        'sload',
                        [
                            'add', ['sha3_32', iter_list_node],
                            ['mload', i_pos]
                        ]
                    ]
                ],
                parse_body(self.stmt.body, self.context),
            ]
            o = LLLnode.from_list(
                ['seq', ['repeat', i_pos, 0, count, body]],
                typ=None,
                pos=getpos(self.stmt))

    # NOTE(review): if none of the three branches above matched, ``o`` is
    # unbound at the return below.
    del self.context.vars[varname]
    # this kind of open access to the vars dict should be disallowed.
    # we should use member functions to provide an API for these kinds
    # of operations.
    del self.context.vars[self.context._mangle(i_pos_raw_name)]
    del self.context.forvars[varname]
    return o
def parse_private_function(code: ast.FunctionDef, sig: FunctionSignature, context: Context) -> LLLnode:
    """
    Parse a private function (FuncDef), and produce full function body.

    The generated LLL is built in this order: the callback pointer placeholder
    is allocated, argument copiers/unpackers are collected in ``clampers``,
    every argument is registered in ``context.vars``, and finally the body is
    wrapped either in a single signature check (no default arguments) or in a
    chain of per-signature entry points that all jump to a shared label
    (with default arguments).

    :param code: ast of function
    :param sig: the FunctionSignature
    :param context: compilation context; its memory allocator, ``vars`` and
        ``callback_ptr`` are modified by this function
    :return: full sig compare & function body
    """
    validate_private_function(code, sig)

    # Get nonreentrant lock
    nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(
        sig, context.global_ctx)

    # Create callback_ptr, this stores a destination in the bytecode for a private
    # function to jump to after a function has executed.
    clampers: List[LLLnode] = []

    # Allocate variable space.
    context.memory_allocator.increase_memory(sig.max_copy_size)

    _post_callback_ptr = f"{sig.name}_{sig.method_id}_post_callback_ptr"
    context.callback_ptr = context.new_placeholder(typ=BaseType('uint256'))
    clampers.append(
        LLLnode.from_list(
            ['mstore', context.callback_ptr, 'pass'],
            annotation='pop callback pointer',
        ))
    if sig.total_default_args > 0:
        clampers.append(LLLnode.from_list(['label', _post_callback_ptr]))

    # private functions without return types need to jump back to
    # the calling function, as there is no return statement to handle the
    # jump.
    if sig.output_type is None:
        stop_func = [['jump', ['mload', context.callback_ptr]]]
    else:
        stop_func = [['stop']]

    # Generate copiers
    # Note: when there are base args AND default args, no copier is added
    # here; copying is done per default signature further below.
    if len(sig.base_args) == 0:
        copier = ['pass']
        clampers.append(LLLnode.from_list(copier))
    elif sig.total_default_args == 0:
        copier = get_private_arg_copier(
            total_size=sig.base_copy_size,
            memory_dest=MemoryPositions.RESERVED_MEMORY)
        clampers.append(LLLnode.from_list(copier))

    # Fill variable positions
    # Byte-array-like arguments get freshly allocated memory; every other
    # argument lives at a fixed offset from RESERVED_MEMORY.
    for arg in sig.args:
        if isinstance(arg.typ, ByteArrayLike):
            mem_pos, _ = context.memory_allocator.increase_memory(
                32 * get_size_of_type(arg.typ))
            context.vars[arg.name] = VariableRecord(arg.name, mem_pos, arg.typ, False)
        else:
            context.vars[arg.name] = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )

    # Private function copiers. No clamping for private functions.
    dyn_variable_names = [
        a.name for a in sig.base_args if isinstance(a.typ, ByteArrayLike)
    ]
    if dyn_variable_names:
        i_placeholder = context.new_placeholder(typ=BaseType('uint256'))
        unpackers: List[Any] = []
        for idx, var_name in enumerate(dyn_variable_names):
            var = context.vars[var_name]
            ident = f"_load_args_{sig.method_id}_dynarg{idx}"
            o = make_unpacker(ident=ident,
                              i_placeholder=i_placeholder,
                              begin_pos=var.pos)
            unpackers.append(o)

        # NOTE(review): dyn_variable_names is non-empty in this branch, so
        # `unpackers` always has at least one entry and this guard looks
        # unreachable.
        if not unpackers:
            unpackers = ['pass']

        # 0 added to complete full overarching 'seq' statement, see private_label.
        unpackers.append(0)
        clampers.append(
            LLLnode.from_list(
                ['seq_unchecked'] + unpackers,
                typ=None,
                annotation='dynamic unpacker',
                pos=getpos(code),
            ))

    # Function has default arguments.
    if sig.total_default_args > 0:  # Function with default parameters.
        default_sigs = sig_utils.generate_default_arg_sigs(
            code, context.sigs, context.global_ctx)
        sig_chain: List[Any] = ['seq']

        # One entry point is emitted per generated default signature.
        for default_sig in default_sigs:
            sig_compare, private_label = get_sig_statements(
                default_sig, getpos(code))

            # Populate unset default variables
            set_defaults = []
            for arg_name in get_default_names_to_set(sig, default_sig):
                value = Expr(sig.default_values[arg_name], context).lll_node
                var = context.vars[arg_name]
                left = LLLnode.from_list(var.pos,
                                         typ=var.typ,
                                         location='memory',
                                         pos=getpos(code),
                                         mutable=var.mutable)
                set_defaults.append(
                    make_setter(left, value, 'memory', pos=getpos(code)))
            current_sig_arg_names = [x.name for x in default_sig.args]

            # Load all variables in default section, if private,
            # because the stack is a linear pipe.
            copier_arg_count = len(default_sig.args)
            copier_arg_names = current_sig_arg_names

            # Order copier_arg_names, this is very important.
            copier_arg_names = [
                x.name for x in default_sig.args if x.name in copier_arg_names
            ]

            # Variables to be populated from calldata/stack.
            default_copiers: List[Any] = []
            if copier_arg_count > 0:
                # Get map of variables in calldata, with their offsets
                # NOTE(review): calldata_offset_map is filled here but never
                # read anywhere else in this function.
                offset = 4
                calldata_offset_map = {}
                for arg in default_sig.args:
                    calldata_offset_map[arg.name] = offset
                    offset += (32 if isinstance(arg.typ, ByteArrayLike) else
                               get_size_of_type(arg.typ) * 32)

                # Copy set default parameters from calldata
                # Byte-array-like arguments are copied as a single 32-byte
                # word and remembered in `dynamics` for unpacking afterwards.
                dynamics = []
                for arg_name in copier_arg_names:
                    var = context.vars[arg_name]
                    if isinstance(var.typ, ByteArrayLike):
                        _size = 32
                        dynamics.append(var.pos)
                    else:
                        _size = var.size * 32
                    default_copiers.append(
                        get_private_arg_copier(
                            memory_dest=var.pos,
                            total_size=_size,
                        ))

                # Unpack byte array if necessary.
                if dynamics:
                    i_placeholder = context.new_placeholder(
                        typ=BaseType('uint256'))
                    for idx, var_pos in enumerate(dynamics):
                        ident = f'unpack_default_sig_dyn_{default_sig.method_id}_arg{idx}'
                        default_copiers.append(
                            make_unpacker(
                                ident=ident,
                                i_placeholder=i_placeholder,
                                begin_pos=var_pos,
                            ))
                default_copiers.append(0)  # for over arching seq, POP

            # Entry point for this signature: store the callback pointer, set
            # the defaults, copy the supplied arguments, then jump to the
            # shared post-callback label where the body starts.
            sig_chain.append([
                'if', sig_compare,
                [
                    'seq',
                    private_label,
                    LLLnode.from_list([
                        'mstore',
                        context.callback_ptr,
                        'pass',
                    ], annotation='pop callback pointer', pos=getpos(code)),
                    ['seq'] + set_defaults if set_defaults else ['pass'],
                    ['seq_unchecked'] + default_copiers if default_copiers else ['pass'],
                    ['goto', _post_callback_ptr]
                ]
            ])

        # With private functions all variable loading occurs in the default
        # function sub routine.
        # Only the label is emitted in front of the body here; the
        # `clampers` list built above is not used in this branch.
        _clampers = [['label', _post_callback_ptr]]

        # Function with default parameters.
        o = LLLnode.from_list(
            [
                'seq',
                sig_chain,
                [
                    'if', 0,  # can only be jumped into
                    [
                        'seq',
                        ['seq'] +
                        nonreentrant_pre +
                        _clampers +
                        [parse_body(c, context) for c in code.body] +
                        nonreentrant_post +
                        stop_func
                    ],
                ],
            ], typ=None, pos=getpos(code))

    else:
        # Function without default parameters.
        sig_compare, private_label = get_sig_statements(sig, getpos(code))
        o = LLLnode.from_list([
            'if',
            sig_compare,
            ['seq'] +
            [private_label] +
            nonreentrant_pre +
            clampers +
            [parse_body(c, context) for c in code.body] +
            nonreentrant_post +
            stop_func
        ], typ=None, pos=getpos(code))
        # NOTE(review): redundant with the final return below; both return
        # the same `o`.
        return o

    return o
def process_arg(index, arg, expected_arg_typelist, function_name, context):
    """
    Match a single call argument against the accepted type specifiers.

    ``expected_arg_typelist`` may be one specifier, a tuple of specifiers, or
    an ``Optional`` wrapper whose ``.typ`` holds either of those. Specifiers
    are tried in order and the first one that accepts ``arg`` determines the
    return value:

    * ``'num_literal'``   -> the constant's value or the literal's ``n``
    * ``'str_literal'``   -> the string converted to ``bytes``
    * ``'bytes_literal'`` -> the literal's ``s``
    * ``'name_literal'``  -> the name, or ``'bytes[N]'`` for a ``bytes``
      subscript
    * ``'*'``             -> ``arg`` unchanged
    * ``'bytes'`` / ``'string'`` -> the parsed LLL node, if its type matches
    * anything else is parsed as a type and compared with the parsed argument

    :param index: position of the argument, used only in error messages
    :param arg: AST node of the argument
    :param expected_arg_typelist: accepted specifier(s), see above
    :param function_name: name of the called function, used in error messages
    :param context: compilation context
    :return: a Python literal value, the AST node itself, or an LLL node,
        depending on the matching specifier
    :raises InvalidLiteral: if a string literal holds a character whose code
        point is 256 or higher
    :raises TypeMismatch: if no specifier accepts the argument
    """
    if isinstance(expected_arg_typelist, Optional):
        expected_arg_typelist = expected_arg_typelist.typ
    if not isinstance(expected_arg_typelist, tuple):
        expected_arg_typelist = (expected_arg_typelist, )

    # Parsed form of `arg`, reused across iterations once computed.
    vsub = None
    for expected_arg in expected_arg_typelist:
        if expected_arg == 'num_literal':
            if context.constants.is_constant_of_base_type(arg, ('uint256', 'int128')):
                return context.constants.get_constant(arg.id, None).value
            if isinstance(arg, (sri_ast.Int, sri_ast.Decimal)):
                return arg.n
        elif expected_arg == 'str_literal':
            if isinstance(arg, sri_ast.Str):
                # Reject the first character that does not fit in one byte.
                for c in arg.s:
                    if ord(c) >= 256:
                        raise InvalidLiteral(
                            f"Cannot insert special character {c} into byte array",
                            arg,
                        )
                # latin-1 maps code points 0..255 to the identical byte
                # value, so this equals concatenating bytes([ord(c)]) per
                # character, without the quadratic bytes concatenation.
                return arg.s.encode('latin-1')
        elif expected_arg == 'bytes_literal':
            if isinstance(arg, sri_ast.Bytes):
                return arg.s
        elif expected_arg == 'name_literal':
            if isinstance(arg, sri_ast.Name):
                return arg.id
            # getattr: the subscript base is not necessarily a Name node; a
            # base without `.id` must fall through to TypeMismatch instead of
            # crashing with AttributeError.
            elif (isinstance(arg, sri_ast.Subscript)
                    and getattr(arg.value, 'id', None) == 'bytes'):
                return f'bytes[{arg.slice.value.n}]'
        elif expected_arg == '*':
            return arg
        elif expected_arg == 'bytes':
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, ByteArrayType):
                return sub
        elif expected_arg == 'string':
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, StringType):
                return sub
        else:
            # Does not work for unit-endowed types inside compound types, e.g. timestamp[2]
            parsed_expected_type = context.parse_type(
                sri_ast.parse_to_ast(expected_arg)[0].value,
                'memory',
            )
            if isinstance(parsed_expected_type, BaseType):
                vsub = vsub or Expr.parse_value_expr(arg, context)

                # An in-bounds integer literal is accepted for either
                # integer type, regardless of which one it was parsed as.
                is_valid_integer = (
                    (
                        expected_arg in ('int128', 'uint256')
                        and isinstance(vsub.typ, BaseType)
                    ) and (
                        vsub.typ.typ in ('int128', 'uint256')
                        and vsub.typ.is_literal
                    ) and (
                        SizeLimits.in_bounds(expected_arg, vsub.value)
                    )
                )

                if is_base_type(vsub.typ, expected_arg):
                    return vsub
                elif is_valid_integer:
                    return vsub
            else:
                vsub = vsub or Expr(arg, context).lll_node
                if vsub.typ == parsed_expected_type:
                    # The argument is parsed again here rather than
                    # returning `vsub`; kept as in the original.
                    return Expr(arg, context).lll_node

    if len(expected_arg_typelist) == 1:
        raise TypeMismatch(
            f"Expecting {expected_arg} for argument {index} of {function_name}",
            arg
        )
    else:
        raise TypeMismatch(
            f"Expecting one of {expected_arg_typelist} for argument {index} of {function_name}",
            arg
        )