def generate_default_arg_sigs(code, _contracts, _custom_units):
    """Return one signature per way of calling a function with default args.

    When ``code`` declares no defaults, a single-element list holding the
    signature of ``code`` itself is returned.  Otherwise, with ``n`` defaults,
    ``n + 1`` signatures are produced: the non-default arguments followed by
    the first ``k`` defaulted arguments, for ``k = 0 .. n``.

    :param code: function definition node exposing ``args.args`` and
        ``args.defaults``.
    :param _contracts: forwarded as ``sigs`` to
        ``FunctionSignature.from_definition``.
    :param _custom_units: forwarded as ``custom_units`` to
        ``FunctionSignature.from_definition``.
    :return: list of signatures, shortest argument list first.
    """
    num_defaults = len(code.args.defaults)
    if num_defaults == 0:
        return [
            FunctionSignature.from_definition(
                code,
                sigs=_contracts,
                custom_units=_custom_units,
            )
        ]

    # Slicing with a negative bound is only safe because num_defaults > 0 here.
    required_args = code.args.args[:-num_defaults]
    optional_args = code.args.args[-num_defaults:]

    # Signature strings are gathered alongside the signatures, but only the
    # signatures themselves are returned.
    default_sig_strs = []
    sig_fun_defs = []
    for included in range(num_defaults + 1):
        variant = copy.deepcopy(code)
        variant.args.args = copy.deepcopy(required_args)
        # NOTE(review): this assigns `args.default` (singular) while the count
        # above is read from `args.defaults` -- confirm which one is intended.
        variant.args.default = []
        # Re-attach the leading `included` optional arguments (not copied).
        variant.args.args.extend(optional_args[:included])
        variant_sig = FunctionSignature.from_definition(
            variant,
            sigs=_contracts,
            custom_units=_custom_units,
        )
        default_sig_strs.append(variant_sig.sig)
        sig_fun_defs.append(variant_sig)
    return sig_fun_defs
def mk_full_signature(global_ctx, sig_formatter=None):
    """Format the signatures of all events and non-internal functions.

    :param global_ctx: object exposing ``_events``, ``_defs``, ``_contracts``,
        ``_structs`` and ``_constants``.
    :param sig_formatter: callable applied to each signature; when ``None``
        the module's ``_default_sig_formatter`` is used.
    :return: list of formatted signatures, events first, then functions
        (one entry per default-argument variant).
    """
    formatter = _default_sig_formatter if sig_formatter is None else sig_formatter

    # Events come first in the output.
    formatted = [
        formatter(EventSignature.from_declaration(event_def, global_ctx))
        for event_def in global_ctx._events
    ]

    # Then every function that is not internal, expanded per default-arg variant.
    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_structs=global_ctx._structs,
            constants=global_ctx._constants,
        )
        if func_sig.internal:
            continue
        variants = generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx)
        for variant in variants:
            formatted.append(formatter(variant))
    return formatted
def call_lookup_specs(stmt_expr, context):
    """Resolve the signature targeted by an attribute-style call expression.

    :param stmt_expr: call node; ``func.attr`` names the method and ``args``
        holds the positional argument nodes.
    :param context: parsing context; its ``sigs`` are searched.
    :return: ``(method_name, expr_args, sig)`` where ``expr_args`` are the
        ``lll_node`` values of the parsed arguments.
    """
    # Imported locally, as in the rest of this module's call helpers.
    from vyper.parser.expr import Expr

    method_name = stmt_expr.func.attr
    expr_args = []
    for arg_node in stmt_expr.args:
        expr_args.append(Expr(arg_node, context).lll_node)
    sig = FunctionSignature.lookup_sig(
        context.sigs,
        method_name,
        expr_args,
        stmt_expr,
        context,
    )
    return method_name, expr_args, sig
def parse_external_contracts(external_contracts, _contracts, _structs):
    """Populate ``external_contracts`` with signatures of declared contracts.

    :param external_contracts: dict updated in place, keyed by contract name;
        each value maps function name to its signature.
    :param _contracts: mapping of contract name to a list of function
        definition nodes.
    :param _structs: forwarded as ``custom_structs`` when building signatures.
    :return: the same ``external_contracts`` dict.
    :raises FunctionDeclarationException: a contract declares a function name
        more than once.
    :raises StructureException: a function body is not exactly one bare
        ``modifying`` or ``constant`` name.
    """
    for contract_name in _contracts:
        contract_defs = _contracts[contract_name]
        def_names = [func_def.name for func_def in contract_defs]
        signatures = {}

        if len(set(def_names)) < len(contract_defs):
            repeated = [name for name in def_names if def_names.count(name) > 1]
            raise FunctionDeclarationException(
                "Duplicate function name: %s" % repeated[0]
            )

        for func_def in contract_defs:
            body = func_def.body
            # The body must be a single expression naming the call type.
            has_call_type = (
                len(body) == 1
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Name)
                and body[0].value.id in ('modifying', 'constant')
            )
            if not has_call_type:
                raise StructureException(
                    'constant or modifying call type must be specified',
                    func_def,
                )
            constant = body[0].value.id == 'constant'

            # Already-defined structs are made available to the signature.
            sig = FunctionSignature.from_definition(
                func_def,
                contract_def=True,
                constant=constant,
                custom_structs=_structs,
            )
            signatures[sig.name] = sig

        external_contracts[contract_name] = signatures
    return external_contracts
def mk_full_signature(code, sig_formatter=None, interface_codes=None):
    """Format the signatures of all events and non-private functions in ``code``.

    :param code: input handed to ``GlobalContext.get_global_context``.
    :param sig_formatter: callable taking ``(signature, unit_descriptions)``;
        when ``None`` the module's ``_default_sig_formatter`` is used.
    :param interface_codes: forwarded to ``GlobalContext.get_global_context``.
    :return: list of formatted signatures, events first, then functions
        (one entry per default-argument variant).
    """
    formatter = _default_sig_formatter if sig_formatter is None else sig_formatter
    formatted = []
    global_ctx = GlobalContext.get_global_context(
        code,
        interface_codes=interface_codes,
    )

    # Events first.
    for event_def in global_ctx._events:
        event_sig = EventSignature.from_declaration(event_def, global_ctx)
        formatted.append(formatter(event_sig, global_ctx._custom_units_descriptions))

    # Then functions; private ones are left out.
    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
            custom_structs=global_ctx._structs,
            constants=global_ctx._constants,
        )
        if func_sig.private:
            continue
        variants = generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx)
        for variant in variants:
            formatted.append(formatter(variant, global_ctx._custom_units_descriptions))
    return formatted
def validate_external_function(code: ast.FunctionDef, sig: FunctionSignature, global_ctx: GlobalContext) -> None:
    """
    Check an external function definition.

    :param code: definition node, attached to the raised exception.
    :param sig: signature of the function being validated.
    :param global_ctx: accepted for interface symmetry; not read here.
    :raises FunctionDeclarationException: the function is the initializer and
        declares default parameters.
    """
    if not sig.is_initializer():
        return
    # Only the initializer is restricted: it may not declare defaults.
    if sig.total_default_args > 0:
        raise FunctionDeclarationException(
            "__init__ function may not have default parameters.",
            code,
        )
def mk_full_signature(code):
    """Return ABI dicts for every event and non-private function in ``code``.

    :param code: input handed to ``get_contracts_and_defs_and_globals``.
    :return: list of ``to_abi_dict()`` results, events first, then functions.
    """
    (
        _contracts,
        _events,
        _defs,
        _globals,
        _custom_units,
    ) = get_contracts_and_defs_and_globals(code)

    abi = [
        EventSignature.from_declaration(event_def, custom_units=_custom_units).to_abi_dict()
        for event_def in _events
    ]
    for func_def in _defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=_contracts,
            custom_units=_custom_units,
        )
        if func_sig.private:
            continue
        abi.append(func_sig.to_abi_dict())
    return abi
def mk_method_identifiers(code):
    """Map signature strings to hex method ids for non-private functions.

    :param code: source passed through ``parse`` before building the global
        context.
    :return: dict of ``sig`` string to ``hex(method_id)``, with one entry per
        default-argument variant of each function.
    """
    identifiers = {}
    global_ctx = GlobalContext.get_global_context(parse(code))
    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
            constants=global_ctx._constants,
        )
        if func_sig.private:
            continue
        variants = generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx)
        for variant in variants:
            identifiers[variant.sig] = hex(variant.method_id)
    return identifiers
def mk_full_signature(code):
    """Return ABI dicts for every event and non-private function in ``code``.

    :param code: input handed to ``GlobalContext.get_global_context``.
    :return: list of ``to_abi_dict()`` results, events first, then functions
        (one entry per default-argument variant).
    """
    global_ctx = GlobalContext.get_global_context(code)

    abi = []
    for event_def in global_ctx._events:
        event_sig = EventSignature.from_declaration(
            event_def,
            custom_units=global_ctx._custom_units,
        )
        abi.append(event_sig.to_abi_dict())

    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
        )
        if func_sig.private:
            continue
        variants = generate_default_arg_sigs(
            func_def,
            global_ctx._contracts,
            global_ctx._custom_units,
        )
        for variant in variants:
            abi.append(variant.to_abi_dict())
    return abi
def mk_single_method_identifier(code, global_ctx):
    """Map signature strings to hex method ids for one function definition.

    :param code: function definition node.
    :param global_ctx: object exposing ``_contracts``, ``_structs`` and
        ``_constants``.
    :return: dict of ``sig`` string to ``hex(method_id)`` covering each
        default-argument variant; empty when the function is private.
    """
    func_sig = FunctionSignature.from_definition(
        code,
        sigs=global_ctx._contracts,
        custom_structs=global_ctx._structs,
        constants=global_ctx._constants,
    )
    if func_sig.private:
        return {}

    variants = generate_default_arg_sigs(code, global_ctx._contracts, global_ctx)
    return {variant.sig: hex(variant.method_id) for variant in variants}
def validate_public_function(code: ast.FunctionDef, sig: FunctionSignature, global_ctx: GlobalContext) -> None:
    """
    Check a public function definition.

    :param code: definition node, attached to any raised exception.
    :param sig: signature of the function being validated.
    :param global_ctx: object whose ``_globals`` holds the global names.
    :raises FunctionDeclarationException: the initializer declares default
        parameters, or an argument name is also a global name.
    """
    # The initializer may not declare defaults.
    if sig.is_initializer():
        if sig.total_default_args > 0:
            raise FunctionDeclarationException(
                "__init__ function may not have default parameters.",
                code,
            )

    # Argument names must not clash with globals; the first clash is reported.
    for func_arg in sig.args:
        if func_arg.name not in global_ctx._globals:
            continue
        message = (
            "Variable name duplicated between "
            "function arguments and globals: " + func_arg.name
        )
        raise FunctionDeclarationException(message, code)
def mk_full_signature_from_json(abi):
    """Build function signatures from a JSON ABI description.

    Only entries whose ``"type"`` is ``"function"`` are used.  Each one is
    turned into a synthetic ``vy_ast.FunctionDef`` decorated ``external`` and
    passed to ``FunctionSignature.from_definition``.

    :param abi: iterable of ABI entries (dicts with ``type``, ``name``,
        ``inputs``, ``outputs`` and optional mutability fields).
    :return: list of signatures, in ABI order.
    """
    function_entries = [entry for entry in abi if entry["type"] == "function"]
    sigs = []
    for entry in function_entries:
        # NOTE(review): 1048576 and 1 are passed through unchanged as the
        # second argument of abi_type_to_ast -- presumably a size bound.
        arg_nodes = [
            vy_ast.arg(
                arg=param["name"],
                annotation=abi_type_to_ast(param["type"], 1048576),
                lineno=0,
                col_offset=0,
            )
            for param in entry["inputs"]
        ]

        outputs = entry["outputs"]
        if len(outputs) == 1:
            returns = abi_type_to_ast(outputs[0]["type"], 1)
        elif len(outputs) > 1:
            returns = vy_ast.Tuple(
                elements=[abi_type_to_ast(out["type"], 1) for out in outputs]
            )
        else:
            returns = None

        # Mutability may come from the legacy boolean flags or from the
        # "stateMutability" field; either one is honoured.
        decorator_list = [vy_ast.Name(id="external")]
        mutability = entry.get("stateMutability")
        if entry.get("constant") or mutability == "view":
            decorator_list.append(vy_ast.Name(id="view"))
        if entry.get("payable") or mutability == "payable":
            decorator_list.append(vy_ast.Name(id="payable"))

        func_def = vy_ast.FunctionDef(
            name=entry["name"],
            args=vy_ast.arguments(args=arg_nodes),
            decorator_list=decorator_list,
            returns=returns,
        )
        sigs.append(
            FunctionSignature.from_definition(
                code=func_def,
                custom_structs=dict(),
                constants=Constants(),
                is_from_json=True,
            )
        )
    return sigs
def mk_full_signature(global_ctx, sig_formatter):
    """Format the signatures of all non-internal functions.

    :param global_ctx: object exposing ``_defs``, ``_contracts`` and
        ``_structs``.
    :param sig_formatter: callable applied to each signature.
    :return: list of formatted signatures, one per default-argument variant.
    """
    formatted = []
    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_structs=global_ctx._structs,
        )
        if func_sig.internal:
            continue
        variants = generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx)
        for variant in variants:
            formatted.append(sig_formatter(variant))
    return formatted
def mk_full_signature_from_json(abi):
    """Build function signatures from a JSON ABI description.

    Only entries whose ``'type'`` is ``'function'`` are used.  Each one is
    turned into a synthetic ``ast.FunctionDef`` decorated ``public`` and
    passed to ``FunctionSignature.from_definition``.

    Mutability is read from the legacy ``constant`` / ``payable`` booleans
    when present, and otherwise from ``stateMutability``.  ABI JSON that
    carries only ``stateMutability`` used to raise ``KeyError`` here.

    :param abi: iterable of ABI entries (dicts with ``type``, ``name``,
        ``inputs``, ``outputs`` and mutability fields).
    :return: list of signatures, in ABI order.
    """
    funcs = [func for func in abi if func['type'] == 'function']
    sigs = []

    for func in funcs:
        args = [
            ast.arg(
                arg=a['name'],
                annotation=abi_type_to_ast(a['type']),
                lineno=0,
                col_offset=0,
            )
            for a in func['inputs']
        ]

        outputs = func['outputs']
        returns = None
        if len(outputs) == 1:
            returns = abi_type_to_ast(outputs[0]['type'])
        elif len(outputs) > 1:
            returns = ast.Tuple(
                elts=[abi_type_to_ast(a['type']) for a in outputs]
            )

        # In the ABI JSON format `constant` is true for both `pure` and `view`
        # functions, so both mutabilities map to the `constant` decorator when
        # the legacy flag is absent.
        mutability = func.get('stateMutability')
        is_constant = func.get('constant') or mutability in ('view', 'pure')
        is_payable = func.get('payable') or mutability == 'payable'

        decorator_list = [ast.Name(id='public')]
        if is_constant:
            decorator_list.append(ast.Name(id='constant'))
        if is_payable:
            decorator_list.append(ast.Name(id='payable'))

        sig = FunctionSignature.from_definition(
            code=ast.FunctionDef(
                name=func['name'],
                args=ast.arguments(args=args),
                decorator_list=decorator_list,
                returns=returns,
            ),
            custom_units=set(),
            custom_structs=dict(),
            constants=Constants(),
        )
        sigs.append(sig)
    return sigs
def mk_method_identifiers(code):
    """Collect method identifiers for every non-private function in ``code``.

    :param code: source passed through ``parse`` before building the global
        context.
    :return: ``dict`` built from the values returned by
        ``get_method_identifier`` on each default-argument variant
        (presumably key/value pairs -- they are fed straight to ``dict``).
    """
    global_ctx = GlobalContext.get_global_context(parse(code))

    pairs = []
    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
        )
        if func_sig.private:
            continue
        variants = generate_default_arg_sigs(
            func_def,
            global_ctx._contracts,
            global_ctx._custom_units,
        )
        for variant in variants:
            identifier = variant.get_method_identifier(
                global_ctx._contracts,
                global_ctx._custom_units,
            )
            pairs.append(identifier)
    return dict(pairs)
def parse_external_interfaces(external_interfaces, global_ctx):
    """Populate ``external_interfaces`` with signatures of known interfaces.

    Interfaces declared in ``global_ctx._contracts`` are parsed from their
    definition nodes; entries of ``global_ctx._interfaces`` are then added,
    keeping only their ``FunctionSignature`` members.

    :param external_interfaces: dict updated in place, keyed by interface
        name; each value maps function name to its signature.
    :param global_ctx: object exposing ``_contracts``, ``_interfaces``,
        ``_structs`` and ``_constants``.
    :return: the same ``external_interfaces`` dict.
    :raises FunctionDeclarationException: an interface declares a function
        name more than once.
    :raises StructureException: a function body is not exactly one bare
        mutability name.
    """
    for interface_name in global_ctx._contracts:
        interface_defs = global_ctx._contracts[interface_name]
        def_names = [func_def.name for func_def in interface_defs]
        signatures = {}

        if len(set(def_names)) < len(interface_defs):
            repeated = [name for name in def_names if def_names.count(name) > 1]
            raise FunctionDeclarationException(
                f"Duplicate function name: {repeated[0]}"
            )

        for func_def in interface_defs:
            body = func_def.body
            # The body must be a single expression naming the mutability.
            # NOTE: Can't import enums here because of circular import
            has_mutability = (
                len(body) == 1
                and isinstance(body[0], vy_ast.Expr)
                and isinstance(body[0].value, vy_ast.Name)
                and body[0].value.id in ("pure", "view", "nonpayable", "payable")
            )
            if not has_mutability:
                raise StructureException(
                    "state mutability of call type must be specified",
                    func_def,
                )
            constant = body[0].value.id in ("view", "pure")

            # Already-defined structs are made available to the signature.
            sig = FunctionSignature.from_definition(
                func_def,
                interface_def=True,
                constant_override=constant,
                custom_structs=global_ctx._structs,
                constants=global_ctx._constants,
            )
            signatures[sig.name] = sig

        external_interfaces[interface_name] = signatures

    # Interfaces held by the global context may replace same-named entries.
    for interface_name, members in global_ctx._interfaces.items():
        functions = {}
        for member in members:
            if isinstance(member, FunctionSignature):
                functions[member.name] = member
        external_interfaces[interface_name] = functions

    return external_interfaces
def parse_other_functions(o, otherfuncs, _globals, sigs, external_contracts, origcode, _custom_units, fallback_function, runtime_only):
    """Assemble the runtime section from the regular and fallback functions.

    ``initializer_lll`` is read from the enclosing scope; it is not defined in
    this function.

    :param o: top-level statement list; extended in place unless
        ``runtime_only`` is set.
    :param otherfuncs: definitions of the regular functions.
    :param sigs: dict filled in place with each regular function's signature.
    :param external_contracts: known external signatures, merged with
        ``{'self': sigs}`` for every ``parse_func`` call.
    :param fallback_function: list whose first item is the fallback
        definition; when empty a ``revert`` node is emitted instead.
    :param runtime_only: when true, return the runtime list itself.
    :return: the runtime list, or ``o`` with a ``return`` of it appended.
    """
    runtime = ['seq', initializer_lll]
    gas_offset = initializer_lll.gas

    for func_def in otherfuncs:
        lookup = {'self': sigs, **external_contracts}
        func_lll = parse_func(func_def, _globals, lookup, origcode, _custom_units)  # noqa E999
        runtime.append(func_lll)
        func_lll.total_gas += gas_offset
        # Every function parsed so far adds 30 gas to the ones that follow.
        gas_offset += 30

        func_sig = FunctionSignature.from_definition(
            func_def,
            external_contracts,
            custom_units=_custom_units,
        )
        func_sig.gas = func_lll.total_gas
        sigs[func_sig.name] = func_sig

    # Fallback: the user-provided one, otherwise an unconditional revert.
    if fallback_function:
        lookup = {'self': sigs, **external_contracts}
        runtime.append(
            parse_func(fallback_function[0], _globals, lookup, origcode, _custom_units)
        )
    else:
        runtime.append(
            LLLnode.from_list(['revert', 0, 0], typ=None, annotation='Default function')
        )

    if runtime_only:
        return runtime
    o.append(['return', 0, ['lll', runtime, 0]])
    return o
def _call_lookup_specs(stmt_expr, context):
    """Resolve the signature targeted by an attribute-style call expression.

    :param stmt_expr: call node; ``func.attr`` names the method, ``args``
        holds the positional argument nodes, ``keywords`` must be empty.
    :param context: parsing context; its ``sigs`` are searched.
    :return: ``(method_name, expr_args, sig)`` where ``expr_args`` are the
        ``lll_node`` values of the parsed arguments.
    :raises TypeMismatch: the call passes keyword arguments.
    """
    # Imported locally, as in the rest of this module's call helpers.
    from vyper.parser.expr import Expr

    method_name = stmt_expr.func.attr

    if len(stmt_expr.keywords) > 0:
        raise TypeMismatch(
            "Cannot use keyword arguments in calls to functions via 'self'",
            stmt_expr,
        )

    expr_args = []
    for arg_node in stmt_expr.args:
        expr_args.append(Expr(arg_node, context).lll_node)

    sig = FunctionSignature.lookup_sig(
        context.sigs,
        method_name,
        expr_args,
        stmt_expr,
        context,
    )
    return method_name, expr_args, sig
def parse_external_contracts(external_contracts, global_ctx):
    """Populate ``external_contracts`` with signatures of known contracts.

    Contracts declared in ``global_ctx._contracts`` are parsed from their
    definition nodes; entries of ``global_ctx._interfaces`` are then added,
    keeping only their ``FunctionSignature`` members.

    :param external_contracts: dict updated in place, keyed by contract name;
        each value maps function name to its signature.
    :param global_ctx: object exposing ``_contracts``, ``_interfaces``,
        ``_structs`` and ``_constants``.
    :return: the same ``external_contracts`` dict.
    :raises FunctionDeclarationException: a contract declares a function name
        more than once.
    :raises StructureException: a function body is not exactly one bare
        ``modifying`` or ``constant`` name.
    """
    for contract_name in global_ctx._contracts:
        contract_defs = global_ctx._contracts[contract_name]
        def_names = [func_def.name for func_def in contract_defs]
        signatures = {}

        if len(set(def_names)) < len(contract_defs):
            repeated = [name for name in def_names if def_names.count(name) > 1]
            raise FunctionDeclarationException(
                f"Duplicate function name: {repeated[0]}"
            )

        for func_def in contract_defs:
            body = func_def.body
            # The body must be a single expression naming the call type.
            has_call_type = (
                len(body) == 1
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Name)
                and body[0].value.id in ('modifying', 'constant')
            )
            if not has_call_type:
                raise StructureException(
                    'constant or modifying call type must be specified',
                    func_def,
                )
            constant = body[0].value.id == 'constant'

            # Already-defined structs are made available to the signature.
            sig = FunctionSignature.from_definition(
                func_def,
                contract_def=True,
                constant_override=constant,
                custom_structs=global_ctx._structs,
                constants=global_ctx._constants,
            )
            signatures[sig.name] = sig

        external_contracts[contract_name] = signatures

    # Interfaces held by the global context may replace same-named entries.
    for interface_name, members in global_ctx._interfaces.items():
        functions = {}
        for member in members:
            if isinstance(member, FunctionSignature):
                functions[member.name] = member
        external_contracts[interface_name] = functions

    return external_contracts
def mk_method_identifiers(code, interface_codes=None):
    """Map signature strings to hex method ids for non-private functions.

    :param code: source passed through ``parse_to_ast`` before building the
        global context.
    :param interface_codes: forwarded to ``GlobalContext.get_global_context``.
    :return: dict of ``sig`` string to ``hex(method_id)``, with one entry per
        default-argument variant of each function.
    """
    # Imported locally rather than at module import time.
    from vyper.parser.parser import parse_to_ast

    identifiers = {}
    tree = parse_to_ast(code)
    global_ctx = GlobalContext.get_global_context(
        tree,
        interface_codes=interface_codes,
    )

    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
            constants=global_ctx._constants,
        )
        if func_sig.private:
            continue
        variants = generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx)
        for variant in variants:
            identifiers[variant.sig] = hex(variant.method_id)
    return identifiers
def parse_tree_to_lll(global_ctx: GlobalContext) -> Tuple[LLLnode, LLLnode]:
    """Translate a parsed module into a pair of LLL nodes.

    :param global_ctx: object exposing ``_defs``, ``_events``, ``_contracts``,
        ``_interfaces`` and ``_structs``.
    :return: ``(o, runtime)`` LLL nodes.  ``runtime`` comes from
        ``parse_other_functions`` when there are regular or default functions,
        and is otherwise a copy of the top-level statement list.
    :raises FunctionDeclarationException: two functions share a name.
    :raises EventDeclarationException: two events share a name.
    """
    # Function names must be unique.
    func_names = [func_def.name for func_def in global_ctx._defs]
    if len(set(func_names)) < len(func_names):
        repeated = [name for name in func_names if func_names.count(name) > 1]
        raise FunctionDeclarationException(
            f"Duplicate function name: {repeated[0]}"
        )

    # Event names must be unique.
    event_names = [event_def.name for event_def in global_ctx._events]
    if len(set(event_names)) < len(event_names):
        repeated = [name for name in event_names if event_names.count(name) > 1]
        raise EventDeclarationException(
            f"""Duplicate event name: {repeated[0]}"""
        )

    # Split the definitions into initializer, default and regular functions.
    initfunc = [_def for _def in global_ctx._defs if is_initializer(_def)]
    defaultfunc = [_def for _def in global_ctx._defs if is_default_func(_def)]
    otherfuncs = [
        _def
        for _def in global_ctx._defs
        if not (is_initializer(_def) or is_default_func(_def))
    ]

    # When no function is payable a single callvalue assertion is placed at
    # the start of the runtime code instead of one per function.
    is_contract_payable = any(
        FunctionSignature.from_definition(
            func_def,
            custom_structs=global_ctx._structs,
        ).mutability == "payable"
        for func_def in global_ctx._defs
    )

    sigs: dict = {}
    external_interfaces: dict = {}

    # Top-level statement list.
    o = ["seq"]
    if global_ctx._contracts or global_ctx._interfaces:
        external_interfaces = parse_external_interfaces(external_interfaces, global_ctx)

    # Initializer, when present.
    if initfunc:
        o.append(init_func_init_lll())
        init_sigs = {"self": sigs, **external_interfaces}
        o.append(parse_function(initfunc[0], init_sigs, global_ctx, False))

    # Regular and default functions, when present.
    if otherfuncs or defaultfunc:
        o, runtime = parse_other_functions(
            o,
            otherfuncs,
            sigs,
            external_interfaces,
            global_ctx,
            defaultfunc,
            is_contract_payable,
        )
    else:
        runtime = list(o)

    if not is_contract_payable:
        # Inserted at index 1, i.e. right after the leading opcode name.
        runtime.insert(1, ["assert", ["iszero", "callvalue"]])

    return LLLnode.from_list(o, typ=None), LLLnode.from_list(runtime, typ=None)
def parse_func(code, sigs, origcode, global_ctx, _vars=None):
    """Compile one function definition into an LLL node.

    The returned node carries three extra attributes set at the end of this
    function: ``context`` (the per-function ``Context``), ``total_gas``
    (``o.gas`` plus ``calc_mem_gas(context.next_mem)``) and ``func_name``.

    Three shapes of output are produced:

    * ``__init__`` and the default function: a plain ``seq`` of the clampers
      followed by the parsed body.
    * functions with default parameters: a chain of per-signature entry
      points that fill in missing defaults and then ``goto`` a shared routine.
    * all other functions: a single ``if`` guarded by the signature check.

    :param code: function definition node.
    :param sigs: signatures used to build the function signature and context.
    :param origcode: original source, stored on the context.
    :param global_ctx: global context (custom units, structs, constants,
        globals, nonreentrant counters).
    :param _vars: initial variable table for the context; a fresh dict is
        used when ``None``.
    :return: the LLL node for the function.
    :raises FunctionDeclarationException: ``__init__`` has default parameters,
        an argument name clashes with a global, the default function takes
        arguments or is private, or a function with a return type contains no
        return statement.
    """
    if _vars is None:
        _vars = {}
    sig = FunctionSignature.from_definition(
        code,
        sigs=sigs,
        custom_units=global_ctx._custom_units,
        custom_structs=global_ctx._structs,
        constants=global_ctx._constants)

    # Get base args for function.
    total_default_args = len(code.args.defaults)
    base_args = sig.args[: -total_default_args] if total_default_args > 0 else sig.args
    # With no defaults this slice is the whole argument list, but the zip
    # below is then empty because `code.args.defaults` is empty.
    default_args = code.args.args[-total_default_args:]
    default_values = dict(
        zip([arg.arg for arg in default_args], code.args.defaults))

    # __init__ function may not have defaults.
    if sig.name == '__init__' and total_default_args > 0:
        raise FunctionDeclarationException(
            "__init__ function may not have default parameters.")

    # Check for duplicate variables with globals
    for arg in sig.args:
        if arg.name in global_ctx._globals:
            raise FunctionDeclarationException(
                "Variable name duplicated between function arguments and globals: " + arg.name)

    # Nonreentrancy guard: no-ops unless the signature carries a key, in which
    # case the storage slot must be zero on entry, is set to 1 for the body
    # and reset to 0 afterwards.
    nonreentrant_pre = [['pass']]
    nonreentrant_post = [['pass']]
    if sig.nonreentrant_key:
        nkey = global_ctx.get_nonrentrant_counter(sig.nonreentrant_key)
        nonreentrant_pre = [[
            'seq', ['assert', ['iszero', ['sload', nkey]]], ['sstore', nkey, 1]
        ]]
        nonreentrant_post = [['sstore', nkey, 0]]

    # Create a local (per function) context.
    context = Context(
        vars=_vars,
        global_ctx=global_ctx,
        sigs=sigs,
        return_type=sig.output_type,
        constancy=Constancy.Constant if sig.const else Constancy.Mutable,
        is_payable=sig.payable,
        origcode=origcode,
        is_private=sig.private,
        method_id=sig.method_id)

    # Copy calldata to memory for fixed-size arguments.
    # Byte-array-like arguments count for one 32-byte word here; all other
    # types count for their full size.
    max_copy_size = sum([
        32 if isinstance(arg.typ, ByteArrayLike) else get_size_of_type(arg.typ) * 32
        for arg in sig.args
    ])
    base_copy_size = sum([
        32 if isinstance(arg.typ, ByteArrayLike) else get_size_of_type(arg.typ) * 32
        for arg in base_args
    ])
    context.next_mem += max_copy_size

    clampers = []

    # Create callback_ptr, this stores a destination in the bytecode for a private
    # function to jump to after a function has executed.
    _post_callback_ptr = "{}_{}_post_callback_ptr".format(
        sig.name, sig.method_id)
    if sig.private:
        context.callback_ptr = context.new_placeholder(typ=BaseType('uint256'))
        clampers.append(
            LLLnode.from_list(
                ['mstore', context.callback_ptr, 'pass'],
                annotation='pop callback pointer',
            ))
        if total_default_args > 0:
            clampers.append(['label', _post_callback_ptr])

    # private functions without return types need to jump back to
    # the calling function, as there is no return statement to handle the
    # jump.
    stop_func = [['stop']]
    if sig.output_type is None and sig.private:
        stop_func = [['jump', ['mload', context.callback_ptr]]]

    # Copier for the base (non-default) arguments.
    if not len(base_args):
        copier = 'pass'
    elif sig.name == '__init__':
        copier = [
            'codecopy', MemoryPositions.RESERVED_MEMORY, '~codelen', base_copy_size
        ]
    else:
        copier = get_arg_copier(sig=sig, total_size=base_copy_size, memory_dest=MemoryPositions.RESERVED_MEMORY)
    clampers.append(copier)

    # Add asserts for payable and internal
    # private never gets payable check.
    if not sig.payable and not sig.private:
        clampers.append(['assert', ['iszero', 'callvalue']])

    # Fill variable positions
    for i, arg in enumerate(sig.args):
        # Only base args of non-private functions are clamped here. The
        # clamper must be built before `context.next_mem` is advanced below.
        if i < len(base_args) and not sig.private:
            clampers.append(
                make_clamper(
                    arg.pos,
                    context.next_mem,
                    arg.typ,
                    sig.name == '__init__',
                ))
        if isinstance(arg.typ, ByteArrayLike):
            # Byte-array-like arguments get freshly reserved memory.
            context.vars[arg.name] = VariableRecord(arg.name, context.next_mem, arg.typ, False)
            context.next_mem += 32 * get_size_of_type(arg.typ)
        else:
            # Everything else lives at its position inside the reserved area.
            context.vars[arg.name] = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )

    # Private function copiers. No clamping for private functions.
    dyn_variable_names = [
        a.name for a in base_args if isinstance(a.typ, ByteArrayLike)
    ]
    if sig.private and dyn_variable_names:
        i_placeholder = context.new_placeholder(typ=BaseType('uint256'))
        unpackers = []
        for idx, var_name in enumerate(dyn_variable_names):
            var = context.vars[var_name]
            ident = "_load_args_%d_dynarg%d" % (sig.method_id, idx)
            o = make_unpacker(ident=ident, i_placeholder=i_placeholder, begin_pos=var.pos)
            unpackers.append(o)

        if not unpackers:
            unpackers = ['pass']

        clampers.append(
            LLLnode.from_list(
                # [0] to complete full overarching 'seq' statement, see private_label.
                ['seq_unchecked'] + unpackers + [0],
                typ=None,
                annotation='dynamic unpacker',
                pos=getpos(code),
            ))

    # Create "clampers" (input well-formedness checkers)
    # Return function body
    # NOTE(review): the `__init__` and default-function branches below do not
    # include nonreentrant_pre / nonreentrant_post -- confirm this is intended.
    if sig.name == '__init__':
        o = LLLnode.from_list(
            ['seq'] + clampers + [parse_body(code.body, context)],
            pos=getpos(code),
        )
    elif is_default_func(sig):
        if len(sig.args) > 0:
            raise FunctionDeclarationException(
                'Default function may not receive any arguments.', code)
        if sig.private:
            raise FunctionDeclarationException(
                'Default function may only be public.',
                code,
            )
        o = LLLnode.from_list(
            ['seq'] + clampers + [parse_body(code.body, context)],
            pos=getpos(code),
        )
    else:
        if total_default_args > 0:
            # Function with default parameters.
            function_routine = "{}_{}".format(sig.name, sig.method_id)
            default_sigs = generate_default_arg_sigs(code, sigs, global_ctx)
            sig_chain = ['seq']

            # One entry point per default-argument variant.
            for default_sig in default_sigs:
                sig_compare, private_label = get_sig_statements(
                    default_sig, getpos(code))

                # Populate unset default variables
                populate_arg_count = len(sig.args) - len(default_sig.args)
                set_defaults = []
                if populate_arg_count > 0:
                    current_sig_arg_names = {x.name for x in default_sig.args}
                    missing_arg_names = [
                        arg.arg for arg in default_args
                        if arg.arg not in current_sig_arg_names
                    ]
                    for arg_name in missing_arg_names:
                        value = Expr(default_values[arg_name], context).lll_node
                        var = context.vars[arg_name]
                        left = LLLnode.from_list(var.pos,
                                                 typ=var.typ,
                                                 location='memory',
                                                 pos=getpos(code),
                                                 mutable=var.mutable)
                        set_defaults.append(
                            make_setter(left, value, 'memory', pos=getpos(code)))

                current_sig_arg_names = {x.name for x in default_sig.args}
                base_arg_names = {arg.name for arg in base_args}
                if sig.private:
                    # Load all variables in default section, if private,
                    # because the stack is a linear pipe.
                    copier_arg_count = len(default_sig.args)
                    copier_arg_names = current_sig_arg_names
                else:
                    copier_arg_count = len(default_sig.args) - len(base_args)
                    copier_arg_names = current_sig_arg_names - base_arg_names
                # Order copier_arg_names, this is very important.
                # (The sets above are unordered; this restores signature order.)
                copier_arg_names = [
                    x.name for x in default_sig.args if x.name in copier_arg_names
                ]

                # Variables to be populated from calldata/stack.
                default_copiers = []
                if copier_arg_count > 0:
                    # Get map of variables in calldata, with their offsets
                    # (offsets start at 4).
                    offset = 4
                    calldata_offset_map = {}
                    for arg in default_sig.args:
                        calldata_offset_map[arg.name] = offset
                        offset += (32 if isinstance(arg.typ, ByteArrayLike) else
                                   get_size_of_type(arg.typ) * 32)

                    # Copy set default parameters from calldata
                    dynamics = []
                    for arg_name in copier_arg_names:
                        var = context.vars[arg_name]
                        calldata_offset = calldata_offset_map[arg_name]
                        if sig.private:
                            _offset = calldata_offset
                            if isinstance(var.typ, ByteArrayLike):
                                # One word is copied now; the position is
                                # remembered so an unpacker is added below.
                                _size = 32
                                dynamics.append(var.pos)
                            else:
                                _size = var.size * 32
                            default_copiers.append(
                                get_arg_copier(
                                    sig=sig,
                                    memory_dest=var.pos,
                                    total_size=_size,
                                    offset=_offset,
                                ))
                        else:
                            # Add clampers.
                            default_copiers.append(
                                make_clamper(
                                    calldata_offset - 4,
                                    var.pos,
                                    var.typ,
                                ))
                            # Add copying code.
                            if isinstance(var.typ, ByteArrayLike):
                                _offset = [
                                    'add', 4, ['calldataload', calldata_offset]
                                ]
                            else:
                                _offset = calldata_offset
                            default_copiers.append(
                                get_arg_copier(
                                    sig=sig,
                                    memory_dest=var.pos,
                                    total_size=var.size * 32,
                                    offset=_offset,
                                ))

                    # Unpack byte array if necessary.
                    if dynamics:
                        i_placeholder = context.new_placeholder(
                            typ=BaseType('uint256'))
                        for idx, var_pos in enumerate(dynamics):
                            ident = 'unpack_default_sig_dyn_%d_arg%d' % (
                                default_sig.method_id, idx)
                            default_copiers.append(
                                make_unpacker(
                                    ident=ident,
                                    i_placeholder=i_placeholder,
                                    begin_pos=var_pos,
                                ))
                    default_copiers.append(0)  # for over arching seq, POP

                # Entry point for this variant: set defaults, copy the
                # supplied arguments, then jump into the shared routine.
                sig_chain.append([
                    'if', sig_compare,
                    [
                        'seq',
                        private_label,
                        ['pass'] if not sig.private else LLLnode.from_list([
                            'mstore',
                            context.callback_ptr,
                            'pass',
                        ], annotation='pop callback pointer', pos=getpos(code)),
                        ['seq'] + set_defaults if set_defaults else ['pass'],
                        ['seq_unchecked'] + default_copiers if default_copiers else ['pass'],
                        [
                            'goto',
                            _post_callback_ptr if sig.private else function_routine
                        ]
                    ]
                ])

            # With private functions all variable loading occurs in the default
            # function sub routine.
            if sig.private:
                _clampers = [['label', _post_callback_ptr]]
            else:
                _clampers = clampers

            # Function with default parameters.
            o = LLLnode.from_list(
                [
                    'seq',
                    sig_chain,
                    [
                        'if', 0,  # can only be jumped into
                        [
                            'seq',
                            ['label', function_routine] if not sig.private else ['pass'],
                            ['seq'] + nonreentrant_pre + _clampers +
                            [parse_body(c, context) for c in code.body] +
                            nonreentrant_post + stop_func
                        ],
                    ],
                ], typ=None, pos=getpos(code))

        else:
            # Function without default parameters.
            sig_compare, private_label = get_sig_statements(sig, getpos(code))
            o = LLLnode.from_list([
                'if', sig_compare,
                ['seq'] + [private_label] + nonreentrant_pre + clampers +
                [parse_body(c, context) for c in code.body] +
                nonreentrant_post + stop_func
            ], typ=None, pos=getpos(code))

    # Check for at least one return statement if necessary.
    if context.return_type and context.function_return_count == 0:
        raise FunctionDeclarationException(
            "Missing return statement in function '%s' " % sig.name, code)

    # Attach bookkeeping used by callers.
    o.context = context
    o.total_gas = o.gas + calc_mem_gas(o.context.next_mem)
    o.func_name = sig.name
    return o
def parse_func(code, sigs, origcode, global_ctx, _vars=None):
    """Compile one function definition into an LLL node.

    The returned node carries three extra attributes set at the end of this
    function: ``context`` (the per-function ``Context``), ``total_gas``
    (``o.gas`` plus ``calc_mem_gas(context.next_mem)``) and ``func_name``.

    ``__init__`` and the default function become a plain ``seq`` of clampers
    plus body.  Functions with default parameters get one entry point per
    default-argument variant, each jumping to a shared routine.  All other
    functions become a single ``if`` comparing ``['mload', 0]`` with the
    method id.

    :param code: function definition node.
    :param sigs: signatures used to build the function signature and context.
    :param origcode: original source, stored on the context.
    :param global_ctx: global context (custom units and globals are read).
    :param _vars: initial variable table for the context; a fresh dict is
        used when ``None``.
    :return: the LLL node for the function.
    :raises FunctionDeclarationException: ``__init__`` has default parameters,
        an argument name clashes with a global, the default function takes
        arguments or is private, or a function with a return type contains no
        return statement.
    """
    if _vars is None:
        _vars = {}
    sig = FunctionSignature.from_definition(
        code, sigs=sigs, custom_units=global_ctx._custom_units)

    # Get base args for function.
    total_default_args = len(code.args.defaults)
    base_args = sig.args[: -total_default_args] if total_default_args > 0 else sig.args
    # With no defaults this slice is the whole argument list, but the zip
    # below is then empty because `code.args.defaults` is empty.
    default_args = code.args.args[-total_default_args:]
    default_values = dict(
        zip([arg.arg for arg in default_args], code.args.defaults))

    # __init__ function may not have defaults.
    if sig.name == '__init__' and total_default_args > 0:
        raise FunctionDeclarationException(
            "__init__ function may not have default parameters.")

    # Check for duplicate variables with globals
    for arg in sig.args:
        if arg.name in global_ctx._globals:
            raise FunctionDeclarationException(
                "Variable name duplicated between function arguments and globals: " + arg.name)

    # Create a context
    context = Context(vars=_vars,
                      globals=global_ctx._globals,
                      sigs=sigs,
                      return_type=sig.output_type,
                      is_constant=sig.const,
                      is_payable=sig.payable,
                      origcode=origcode,
                      custom_units=global_ctx._custom_units)

    # Copy calldata to memory for fixed-size arguments.
    # Byte-array arguments count for one 32-byte word here; all other types
    # count for their full size.
    max_copy_size = sum([
        32 if isinstance(arg.typ, ByteArrayType) else get_size_of_type(arg.typ) * 32
        for arg in sig.args
    ])
    base_copy_size = sum([
        32 if isinstance(arg.typ, ByteArrayType) else get_size_of_type(arg.typ) * 32
        for arg in base_args
    ])
    context.next_mem += max_copy_size

    # Copier for the base (non-default) arguments.
    if not len(base_args):
        copier = 'pass'
    elif sig.name == '__init__':
        copier = [
            'codecopy', MemoryPositions.RESERVED_MEMORY, '~codelen', base_copy_size
        ]
    else:
        copier = [
            'calldatacopy', MemoryPositions.RESERVED_MEMORY, 4, base_copy_size
        ]
    clampers = [copier]

    # Add asserts for payable and internal
    if not sig.payable:
        clampers.append(['assert', ['iszero', 'callvalue']])
    if sig.private:
        clampers.append(['assert', ['eq', 'caller', 'address']])

    # Fill variable positions
    for i, arg in enumerate(sig.args):
        # Only base args are clamped here. The clamper must be built before
        # `context.next_mem` is advanced below.
        if i < len(base_args):
            clampers.append(
                make_clamper(arg.pos, context.next_mem, arg.typ,
                             sig.name == '__init__'))
        if isinstance(arg.typ, ByteArrayType):
            # Byte arrays get freshly reserved memory.
            context.vars[arg.name] = VariableRecord(arg.name, context.next_mem, arg.typ, False)
            context.next_mem += 32 * get_size_of_type(arg.typ)
        else:
            # Everything else lives at its position inside the reserved area.
            context.vars[arg.name] = VariableRecord(
                arg.name, MemoryPositions.RESERVED_MEMORY + arg.pos, arg.typ,
                False)

    # Create "clampers" (input well-formedness checkers)
    # Return function body
    if sig.name == '__init__':
        o = LLLnode.from_list(['seq'] + clampers +
                              [parse_body(code.body, context)],
                              pos=getpos(code))
    elif is_default_func(sig):
        if len(sig.args) > 0:
            raise FunctionDeclarationException(
                'Default function may not receive any arguments.', code)
        if sig.private:
            raise FunctionDeclarationException(
                'Default function may only be public.', code)
        o = LLLnode.from_list(['seq'] + clampers +
                              [parse_body(code.body, context)],
                              pos=getpos(code))
    else:
        # Handle default args if present.
        function_routine = "{}_{}".format(sig.name, sig.method_id)
        if total_default_args > 0:
            default_sigs = generate_default_arg_sigs(code, sigs,
                                                     global_ctx._custom_units)
            sig_chain = ['seq']

            # One entry point per default-argument variant.
            for default_sig_idx, default_sig in enumerate(default_sigs):
                method_id_node = LLLnode.from_list(default_sig.method_id,
                                                   pos=getpos(code),
                                                   annotation='%s' % default_sig.sig)

                # Populate unset default variables
                populate_arg_count = len(sig.args) - len(default_sig.args)
                set_defaults = []
                if populate_arg_count > 0:
                    current_sig_arg_names = {x.name for x in default_sig.args}
                    missing_arg_names = [
                        arg.arg for arg in default_args
                        if arg.arg not in current_sig_arg_names
                    ]
                    for arg_name in missing_arg_names:
                        value = Expr(default_values[arg_name], context).lll_node
                        var = context.vars[arg_name]
                        left = LLLnode.from_list(var.pos,
                                                 typ=var.typ,
                                                 location='memory',
                                                 pos=getpos(code),
                                                 mutable=var.mutable)
                        set_defaults.append(
                            make_setter(left, value, 'memory', pos=getpos(code)))

                # Variables to be populated from calldata
                copier_arg_count = len(default_sig.args) - len(base_args)
                default_copiers = []
                if copier_arg_count > 0:
                    current_sig_arg_names = {x.name for x in default_sig.args}
                    base_arg_names = {arg.name for arg in base_args}
                    # NOTE(review): this is a set, so the loop over it below
                    # emits the copiers in set-iteration order rather than in
                    # signature order -- confirm that is acceptable.
                    copier_arg_names = current_sig_arg_names - base_arg_names

                    # Get map of variables in calldata, with their offsets
                    # (offsets start at 4).
                    offset = 4
                    calldata_offset_map = {}
                    for arg in default_sig.args:
                        calldata_offset_map[arg.name] = offset
                        offset += 32 if isinstance(
                            arg.typ, ByteArrayType) else get_size_of_type(arg.typ) * 32

                    # Copy set default parameters from calldata
                    for arg_name in copier_arg_names:
                        var = context.vars[arg_name]
                        calldata_offset = calldata_offset_map[arg_name]
                        # Add clampers.
                        default_copiers.append(
                            make_clamper(calldata_offset - 4, var.pos, var.typ))
                        # Add copying code.
                        if isinstance(var.typ, ByteArrayType):
                            # The source position is 4 plus the word read from
                            # calldata at this argument's offset.
                            default_copiers.append([
                                'calldatacopy', var.pos,
                                ['add', 4, ['calldataload', calldata_offset]],
                                var.size * 32
                            ])
                        else:
                            default_copiers.append([
                                'calldatacopy', var.pos, calldata_offset,
                                var.size * 32
                            ])

                # Entry point for this variant: set defaults, copy the
                # supplied arguments, then jump into the shared routine.
                sig_chain.append([
                    'if', ['eq', ['mload', 0], method_id_node],
                    [
                        'seq',
                        ['seq'] + set_defaults if set_defaults else ['pass'],
                        ['seq'] + default_copiers if default_copiers else ['pass'],
                        ['goto', function_routine]
                    ]
                ])

            o = LLLnode.from_list(
                [
                    'seq',
                    sig_chain,
                    [
                        'if', 0,  # can only be jumped into
                        [
                            'seq',
                            ['label', function_routine],
                            ['seq'] + clampers +
                            [parse_body(c, context) for c in code.body] +
                            ['stop']
                        ]
                    ]
                ], typ=None, pos=getpos(code))

        else:
            # Function without default parameters.
            method_id_node = LLLnode.from_list(sig.method_id,
                                               pos=getpos(code),
                                               annotation='%s' % sig.sig)
            o = LLLnode.from_list([
                'if', ['eq', ['mload', 0], method_id_node],
                ['seq'] + clampers +
                [parse_body(c, context) for c in code.body] + ['stop']
            ], typ=None, pos=getpos(code))

    # Check for at least one return statement if necessary.
    if context.return_type and context.function_return_count == 0:
        raise FunctionDeclarationException(
            "Missing return statement in function '%s' " % sig.name, code)

    # Attach bookkeeping used by callers.
    o.context = context
    o.total_gas = o.gas + calc_mem_gas(o.context.next_mem)
    o.func_name = sig.name
    return o
def call(self):
    """Compile the call expression held in ``self.expr`` into an LLL node.

    The callee's syntactic shape selects the code path:

    * bare name: looked up in ``dispatch_table``;
    * ``self.method(...)``: a ``call`` to this contract's own address, with
      the result read back from a freshly allocated output placeholder;
    * ``Contract(address).method(...)`` or ``self.<global>.method(...)``:
      delegated to ``external_contract_call``.

    :raises StructureException: unknown bare name or unsupported callee shape
    :raises ConstancyViolationException: a constant context calls a
        non-constant method of ``self``
    :raises TypeMismatchException: the called method's output type is not a
        base, byte-array or tuple type
    """
    from .parser import (
        external_contract_call,
        pack_arguments,
    )
    from vyper.functions import (
        dispatch_table,
    )

    node = self.expr
    ctx = self.context
    callee = node.func

    def call_contract_at_global(contract_name, global_name):
        # The contract address lives in the storage slot of the named global.
        var = ctx.globals[global_name]
        contract_address = unwrap_location(LLLnode.from_list(
            var.pos,
            typ=var.typ,
            location='storage',
            pos=getpos(node),
            annotation='self.' + global_name,
        ))
        value, gas = self._get_external_contract_keywords()
        return external_contract_call(
            node, ctx, contract_name, contract_address,
            pos=getpos(node), value=value, gas=gas,
        )

    # --- bare name: built-in function -------------------------------------
    if isinstance(callee, ast.Name):
        function_name = callee.id
        if function_name in dispatch_table:
            return dispatch_table[function_name](node, ctx)
        err_msg = "Not a top-level function: {}".format(function_name)
        # Keys of sigs['self'] are split at '(' and compared by bare name.
        is_own_method = any(
            key.split('(')[0] == function_name
            for key, _ in ctx.sigs['self'].items()
        )
        if is_own_method:
            err_msg += ". Did you mean self.{}?".format(function_name)
        raise StructureException(err_msg, node)

    # --- self.method(...) --------------------------------------------------
    if (
        isinstance(callee, ast.Attribute)
        and isinstance(callee.value, ast.Name)
        and callee.value.id == "self"
    ):
        expr_args = [Expr(arg, ctx).lll_node for arg in node.args]
        method_name = callee.attr
        sig = FunctionSignature.lookup_sig(ctx.sigs, method_name, expr_args, node, ctx)
        if ctx.is_constant and not sig.const:
            raise ConstancyViolationException(
                "May not call non-constant function '%s' within a constant function." % (method_name),
                getpos(node)
            )
        add_gas = sig.gas  # gas of call
        inargs, inargsize = pack_arguments(sig, expr_args, ctx, pos=getpos(node))
        output_placeholder = ctx.new_placeholder(typ=sig.output_type)

        out_typ = sig.output_type
        if not isinstance(out_typ, (BaseType, ByteArrayType, TupleType)):
            raise TypeMismatchException("Invalid output type: %r" % sig.output_type, node)
        # BaseType keeps precedence over ByteArrayType, as in an
        # if/elif chain that tests BaseType first.
        offset_by_word = (
            not isinstance(out_typ, BaseType)
            and isinstance(out_typ, ByteArrayType)
        )
        returner = output_placeholder + 32 if offset_by_word else output_placeholder

        self_call = [
            'call', ['gas'], ['address'], 0,
            inargs, inargsize,
            output_placeholder, get_size_of_type(sig.output_type) * 32,
        ]
        o = LLLnode.from_list(
            ['seq', ['assert', self_call], returner],
            typ=sig.output_type,
            location='memory',
            pos=getpos(node),
            add_gas_estimate=add_gas,
            annotation='Internal Call: %s' % method_name,
        )
        o.gas += sig.gas
        return o

    # --- Contract(address).method(...) ---------------------------------------
    if isinstance(callee, ast.Attribute) and isinstance(callee.value, ast.Call):
        contract_name = callee.value.func.id
        contract_address = Expr.parse_value_expr(callee.value.args[0], ctx)
        value, gas = self._get_external_contract_keywords()
        return external_contract_call(
            node, ctx, contract_name, contract_address,
            pos=getpos(node), value=value, gas=gas,
        )

    # --- self.<global>.method(...) -------------------------------------------
    base = callee.value
    if isinstance(base, ast.Attribute) and base.attr in ctx.sigs:
        return call_contract_at_global(base.attr, base.attr)
    if isinstance(base, ast.Attribute) and base.attr in ctx.globals:
        return call_contract_at_global(ctx.globals[base.attr].typ.unit, base.attr)

    raise StructureException("Unsupported operator: %r" % ast.dump(node), node)
def parse_public_function(code: ast.FunctionDef,
                          sig: FunctionSignature,
                          context: Context) -> LLLnode:
    """
    Parse a public function (FuncDef), and produce full function body.

    For a function with default parameters, one selector check is emitted
    per signature returned by ``generate_default_arg_sigs``. Each check
    stores the declared defaults for the omitted parameters, copies the
    supplied ones from calldata into memory, and jumps to one shared body.

    :param code: ast of function
    :param sig: the FunctionSignature
    :param context: function context; ``context.vars`` and
        ``context.memory_allocator`` are updated in place
    :return: full sig compare & function body
    :raises FunctionDeclarationException: the default function declares
        arguments
    """

    validate_public_function(code, sig, context.global_ctx)

    # Get nonreentrant lock
    nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(sig, context.global_ctx)

    clampers = []

    # Generate copiers: only the constructor copies its base arguments
    # (from the end of the code); everything else starts with a no-op.
    copier: List[Any] = ['pass']
    if len(sig.base_args) and sig.name == '__init__':
        copier = ['codecopy', MemoryPositions.RESERVED_MEMORY, '~codelen', sig.base_copy_size]
        context.memory_allocator.increase_memory(sig.max_copy_size)
    clampers.append(copier)

    # Add asserts for payable and internal
    if not sig.payable:
        clampers.append(['assert', ['iszero', 'callvalue']])

    # Fill variable positions
    default_args_start_pos = len(sig.base_args)
    for i, arg in enumerate(sig.args):
        if i < len(sig.base_args):
            clampers.append(
                make_arg_clamper(
                    arg.pos,
                    context.memory_allocator.get_next_memory_position(),
                    arg.typ,
                    sig.name == '__init__',
                )
            )
        if isinstance(arg.typ, ByteArrayLike):
            mem_pos, _ = context.memory_allocator.increase_memory(32 * get_size_of_type(arg.typ))
            context.vars[arg.name] = VariableRecord(arg.name, mem_pos, arg.typ, False)
        else:
            if sig.name == '__init__':
                context.vars[arg.name] = VariableRecord(
                    arg.name,
                    MemoryPositions.RESERVED_MEMORY + arg.pos,
                    arg.typ,
                    False,
                )
            elif i >= default_args_start_pos:
                # Default args need to be allocated in memory. Reserve the
                # full size of the type: the copier built further down is
                # asked to write ``var.size * 32`` bytes at this position, so
                # a fixed 32-byte slot would let a multi-word default argument
                # overlap whatever is allocated next.
                default_arg_pos, _ = context.memory_allocator.increase_memory(
                    32 * get_size_of_type(arg.typ)
                )
                context.vars[arg.name] = VariableRecord(
                    name=arg.name,
                    pos=default_arg_pos,
                    typ=arg.typ,
                    mutable=False,
                )
            else:
                context.vars[arg.name] = VariableRecord(
                    name=arg.name,
                    pos=4 + arg.pos,
                    typ=arg.typ,
                    mutable=False,
                    location='calldata'
                )

    # Create "clampers" (input well-formedness checkers)
    # Return function body
    if sig.name == '__init__':
        o = LLLnode.from_list(
            ['seq'] + clampers + [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )
    # Is default function.
    elif sig.is_default_func():
        if len(sig.args) > 0:
            raise FunctionDeclarationException(
                'Default function may not receive any arguments.', code
            )
        o = LLLnode.from_list(
            ['seq'] + clampers + [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )
    # Is a normal function.
    else:
        # Function with default parameters.
        if sig.total_default_args > 0:
            function_routine = f"{sig.name}_{sig.method_id}"
            default_sigs = sig_utils.generate_default_arg_sigs(
                code, context.sigs, context.global_ctx
            )
            sig_chain: List[Any] = ['seq']

            for default_sig in default_sigs:
                sig_compare, _ = get_sig_statements(default_sig, getpos(code))

                # Populate unset default variables
                set_defaults = []
                for arg_name in get_default_names_to_set(sig, default_sig):
                    value = Expr(sig.default_values[arg_name], context).lll_node
                    var = context.vars[arg_name]
                    left = LLLnode.from_list(
                        var.pos,
                        typ=var.typ,
                        location='memory',
                        pos=getpos(code),
                        mutable=var.mutable,
                    )
                    set_defaults.append(make_setter(left, value, 'memory', pos=getpos(code)))

                current_sig_arg_names = {x.name for x in default_sig.args}
                base_arg_names = {arg.name for arg in sig.base_args}
                copier_arg_count = len(default_sig.args) - len(sig.base_args)
                copier_arg_names = list(current_sig_arg_names - base_arg_names)

                # Order copier_arg_names, this is very important.
                copier_arg_names = [
                    x.name for x in default_sig.args if x.name in copier_arg_names
                ]

                # Variables to be populated from calldata/stack.
                default_copiers: List[Any] = []
                if copier_arg_count > 0:
                    # Get map of variables in calldata, with their offsets.
                    # The first 4 bytes of calldata are skipped; a byte-array
                    # argument occupies one 32-byte head word here.
                    offset = 4
                    calldata_offset_map = {}
                    for arg in default_sig.args:
                        calldata_offset_map[arg.name] = offset
                        offset += (
                            32
                            if isinstance(arg.typ, ByteArrayLike)
                            else get_size_of_type(arg.typ) * 32
                        )

                    # Copy default parameters from calldata.
                    for arg_name in copier_arg_names:
                        var = context.vars[arg_name]
                        calldata_offset = calldata_offset_map[arg_name]

                        # Add clampers.
                        default_copiers.append(
                            make_arg_clamper(
                                calldata_offset - 4,
                                var.pos,
                                var.typ,
                            )
                        )
                        # Add copying code. Byte arrays are read through the
                        # offset stored in their head word.
                        _offset: Union[int, List[Any]] = calldata_offset
                        if isinstance(var.typ, ByteArrayLike):
                            _offset = ['add', 4, ['calldataload', calldata_offset]]
                        default_copiers.append(
                            get_public_arg_copier(
                                memory_dest=var.pos,
                                total_size=var.size * 32,
                                offset=_offset,
                            )
                        )

                    default_copiers.append(0)  # for over arching seq, POP

                sig_chain.append([
                    'if', sig_compare,
                    [
                        'seq',
                        ['seq'] + set_defaults if set_defaults else ['pass'],
                        ['seq_unchecked'] + default_copiers if default_copiers else ['pass'],
                        ['goto', function_routine]
                    ]
                ])

            # Function with default parameters.
            o = LLLnode.from_list(
                [
                    'seq',
                    sig_chain,
                    [
                        'if', 0,  # can only be jumped into
                        [
                            'seq',
                            ['label', function_routine],
                            ['seq']
                            + nonreentrant_pre
                            + clampers
                            + [parse_body(c, context) for c in code.body]
                            + nonreentrant_post
                            + [['stop']]
                        ],
                    ],
                ],
                typ=None,
                pos=getpos(code),
            )

        else:
            # Function without default parameters.
            sig_compare, _ = get_sig_statements(sig, getpos(code))
            o = LLLnode.from_list(
                [
                    'if',
                    sig_compare,
                    ['seq']
                    + nonreentrant_pre
                    + clampers
                    + [parse_body(c, context) for c in code.body]
                    + nonreentrant_post
                    + [['stop']]
                ],
                typ=None,
                pos=getpos(code),
            )
    return o
def call(self):
    """Compile the call statement held in ``self.stmt`` into an LLL node.

    The callee's syntactic shape selects the code path:

    * bare name: looked up in ``stmt_dispatch_table``;
    * ``self.method(...)``: a ``call`` to this contract's own address whose
      return data is discarded;
    * ``Contract(address).method(...)`` or ``self.<global>.method(...)``:
      delegated to ``external_contract_call``;
    * ``log.Event(...)``: packs topics and data and emits a ``logN`` node.

    Every exception raised here carries the statement (or its position) so
    the error can be reported against the source.

    :raises StructureException: unknown or expression-only bare name, or an
        unsupported callee shape
    :raises ConstancyViolationException: a constant context calls a
        non-constant method of ``self``
    :raises EventDeclarationException: undeclared event, or wrong number of
        event arguments
    """
    from .parser import (
        pack_arguments,
        pack_logging_data,
        pack_logging_topics,
        external_contract_call,
    )

    func = self.stmt.func

    if isinstance(func, ast.Name):
        if func.id in stmt_dispatch_table:
            return stmt_dispatch_table[func.id](self.stmt, self.context)
        elif func.id in dispatch_table:
            raise StructureException("Function {} can not be called without being used.".format(func.id), self.stmt)
        else:
            raise StructureException("Unknown function: '{}'.".format(func.id), self.stmt)
    elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name) and func.value.id == "self":
        method_name = func.attr
        expr_args = [Expr(arg, self.context).lll_node for arg in self.stmt.args]
        sig = FunctionSignature.lookup_sig(self.context.sigs, method_name, expr_args, self.stmt, self.context)
        if self.context.is_constant and not sig.const:
            raise ConstancyViolationException(
                "May not call non-constant function '%s' within a constant function." % (sig.sig),
                getpos(self.stmt)
            )
        add_gas = self.context.sigs['self'][sig.sig].gas
        inargs, inargsize = pack_arguments(sig, expr_args, self.context, pos=getpos(self.stmt))
        # Output offset and size are both 0: return data is not read back.
        return LLLnode.from_list(
            ['assert', ['call', ['gas'], ['address'], 0, inargs, inargsize, 0, 0]],
            typ=None,
            pos=getpos(self.stmt),
            add_gas_estimate=add_gas,
            annotation='Internal Call: %s' % sig.sig,
        )
    elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Call):
        contract_name = func.value.func.id
        contract_address = Expr.parse_value_expr(func.value.args[0], self.context)
        return external_contract_call(self.stmt, self.context, contract_name, contract_address, pos=getpos(self.stmt))
    elif isinstance(func.value, ast.Attribute) and func.value.attr in self.context.sigs:
        contract_name = func.value.attr
        var = self.context.globals[func.value.attr]
        contract_address = unwrap_location(LLLnode.from_list(
            var.pos,
            typ=var.typ,
            location='storage',
            pos=getpos(self.stmt),
            annotation='self.' + func.value.attr,
        ))
        return external_contract_call(self.stmt, self.context, contract_name, contract_address, pos=getpos(self.stmt))
    elif isinstance(func.value, ast.Attribute) and func.value.attr in self.context.globals:
        contract_name = self.context.globals[func.value.attr].typ.unit
        var = self.context.globals[func.value.attr]
        contract_address = unwrap_location(LLLnode.from_list(
            var.pos,
            typ=var.typ,
            location='storage',
            pos=getpos(self.stmt),
            annotation='self.' + func.value.attr,
        ))
        return external_contract_call(self.stmt, self.context, contract_name, contract_address, pos=getpos(self.stmt))
    elif (
        isinstance(func, ast.Attribute)
        # Only a plain name has ``.id``; without this check a callee such as
        # ``self.<unknown>.method`` would crash with AttributeError instead
        # of reaching the StructureException below.
        and isinstance(func.value, ast.Name)
        and func.value.id == 'log'
    ):
        if func.attr not in self.context.sigs['self']:
            raise EventDeclarationException("Event not declared yet: %s" % func.attr, self.stmt)
        event = self.context.sigs['self'][func.attr]
        if len(event.indexed_list) != len(self.stmt.args):
            raise EventDeclarationException(
                "%s received %s arguments but expected %s" % (event.name, len(self.stmt.args), len(event.indexed_list)),
                self.stmt,
            )
        # Split the arguments into indexed topics and plain data, keeping
        # each one paired with its declared event argument.
        expected_topics, topics = [], []
        expected_data, data = [], []
        for pos, is_indexed in enumerate(event.indexed_list):
            if is_indexed:
                expected_topics.append(event.args[pos])
                topics.append(self.stmt.args[pos])
            else:
                expected_data.append(event.args[pos])
                data.append(self.stmt.args[pos])
        topics = pack_logging_topics(event.event_id, topics, expected_topics, self.context, pos=getpos(self.stmt))
        inargs, inargsize, inargsize_node, inarg_start = pack_logging_data(expected_data, data, self.context, pos=getpos(self.stmt))
        # Use the static size unless the packer supplied a memory location
        # to load the size from.
        if inargsize_node is None:
            sz = inargsize
        else:
            sz = ['mload', inargsize_node]
        return LLLnode.from_list(
            [
                'seq',
                inargs,
                LLLnode.from_list(
                    ["log" + str(len(topics)), inarg_start, sz] + topics,
                    add_gas_estimate=inargsize * 10,
                ),
            ],
            typ=None,
            pos=getpos(self.stmt),
        )
    else:
        raise StructureException("Unsupported operator: %r" % ast.dump(self.stmt), self.stmt)
def make_call(stmt_expr, context):
    """Build the LLL node for a call to one of this contract's own functions.

    The generated sequence saves the caller's memory variables, pushes the
    packed arguments, jumps to the ``priv_<method_id>`` label, then restores
    return values and the saved variables (see the step list below).

    :param stmt_expr: call node; ``func.attr`` names the method and ``args``
        holds the argument expressions
    :param context: current function context (variables, signatures,
        internal-variable allocation)
    :return: LLLnode typed with the callee's output type, located in memory
    :raises StateAccessViolation: a constant context calls a function whose
        mutability is neither "view" nor "pure"
    :raises StructureException: the target is not internal, or no arguments
        were given to a function that declares some
    """

    # ** Internal Call **
    # Steps:
    # (x) push current local variables
    # (x) push arguments
    # (x) push jumpdest (callback ptr)
    # (x) jump to label
    # (x) pop return values
    # (x) pop local variables

    pop_local_vars = []
    push_local_vars = []
    pop_return_values = []
    push_args = []
    method_name = stmt_expr.func.attr

    from vyper.parser.expr import parse_sequence
    pre_init, expr_args = parse_sequence(stmt_expr, stmt_expr.args, context)
    sig = FunctionSignature.lookup_sig(
        context.sigs,
        method_name,
        expr_args,
        stmt_expr,
        context,
    )

    if context.is_constant() and sig.mutability not in ("view", "pure"):
        raise StateAccessViolation(
            f"May not call state modifying function "
            f"'{method_name}' within {context.pp_constancy()}.",
            getpos(stmt_expr),
        )

    if not sig.internal:
        raise StructureException("Cannot call external functions via 'self'", stmt_expr)

    # Push local variables.
    # Only variables located in memory are saved; each entry is (pos, size).
    var_slots = [(v.pos, v.size) for name, v in context.vars.items()
                 if v.location == "memory"]
    if var_slots:
        var_slots.sort(key=lambda x: x[0])

        if len(var_slots) > 10:
            # if memory is large enough, push and pop it via iteration
            # The range runs from the lowest variable position to the end of
            # the highest one (size appears to be counted in 32-byte words).
            mem_from, mem_to = var_slots[0][0], var_slots[-1][0] + var_slots[-1][1] * 32
            i_placeholder = context.new_internal_variable(BaseType("uint256"))
            # Labels are made unique per call site via line and column.
            local_save_ident = f"_{stmt_expr.lineno}_{stmt_expr.col_offset}"
            push_loop_label = "save_locals_start" + local_save_ident
            pop_loop_label = "restore_locals_start" + local_save_ident
            # Ascending loop: load every word in [mem_from, mem_to).
            push_local_vars = [
                ["mstore", i_placeholder, mem_from],
                ["label", push_loop_label],
                ["mload", ["mload", i_placeholder]],
                [
                    "mstore", i_placeholder,
                    ["add", ["mload", i_placeholder], 32]
                ],
                [
                    "if", ["lt", ["mload", i_placeholder], mem_to],
                    ["goto", push_loop_label]
                ],
            ]
            # Descending loop: store the words back in reverse order.
            pop_local_vars = [
                ["mstore", i_placeholder, mem_to - 32],
                ["label", pop_loop_label],
                ["mstore", ["mload", i_placeholder], "pass"],
                [
                    "mstore", i_placeholder,
                    ["sub", ["mload", i_placeholder], 32]
                ],
                [
                    "if", ["ge", ["mload", i_placeholder], mem_from],
                    ["goto", pop_loop_label]
                ],
            ]
        else:
            # for smaller memory, hardcode the mload/mstore locations
            push_mem_slots = []
            for pos, size in var_slots:
                push_mem_slots.extend([pos + i * 32 for i in range(size)])

            push_local_vars = [["mload", pos] for pos in push_mem_slots]
            # Restored in reverse of the order they were pushed.
            pop_local_vars = [["mstore", pos, "pass"]
                              for pos in push_mem_slots[::-1]]

    # Push Arguments
    if expr_args:
        inargs, inargsize, arg_pos = pack_arguments(sig,
                                                    expr_args,
                                                    context,
                                                    stmt_expr,
                                                    is_external_call=False)
        push_args += [
            inargs
        ]  # copy arguments first, to not mess up the push/pop sequencing.

        static_arg_size = 32 * sum(
            [get_static_size_of_type(arg.typ) for arg in expr_args])
        # End of the static section of the packed arguments.
        static_pos = int(arg_pos + static_arg_size)
        needs_dyn_section = any(
            [has_dynamic_data(arg.typ) for arg in expr_args])

        if needs_dyn_section:
            ident = f"push_args_{sig.method_id}_{stmt_expr.lineno}_{stmt_expr.col_offset}"
            start_label = ident + "_start"
            end_label = ident + "_end"
            i_placeholder = context.new_internal_variable(BaseType("uint256"))

            # Calculate copy start position.
            # Given | static | dynamic | section in memory,
            # copy backwards so the values are in order on the stack.
            # We calculate i, the end of the whole encoded part
            # (i.e. the starting index for copy)
            # by taking ceil32(len<arg>) + offset<arg> + arg_pos
            # for the last dynamic argument and arg_pos is the start
            # the whole argument section.
            # NOTE(review): last_idx is assigned only when an argument is
            # ByteArrayLike; if has_dynamic_data() can be true for any other
            # type, last_idx is unbound below -- confirm.
            idx = 0
            for arg in expr_args:
                if isinstance(arg.typ, ByteArrayLike):
                    last_idx = idx
                idx += get_static_size_of_type(arg.typ)
            push_args += [[
                "with", "offset", ["mload", arg_pos + last_idx * 32],
                [
                    "with", "len_pos", ["add", arg_pos, "offset"],
                    [
                        "with", "len_value", ["mload", "len_pos"],
                        [
                            "mstore", i_placeholder,
                            ["add", "len_pos", ["ceil32", "len_value"]]
                        ],
                    ],
                ],
            ]]
            # loop from end of dynamic section to start of dynamic section,
            # pushing each element onto the stack.
            push_args += [
                ["label", start_label],
                [
                    "if", ["lt", ["mload", i_placeholder], static_pos],
                    ["goto", end_label]
                ],
                ["mload", ["mload", i_placeholder]],
                [
                    "mstore", i_placeholder,
                    ["sub", ["mload", i_placeholder], 32]
                ],  # decrease i
                ["goto", start_label],
                ["label", end_label],
            ]

        # push static section
        # Words are loaded from the highest address down to arg_pos.
        push_args += [["mload", pos]
                      for pos in reversed(range(arg_pos, static_pos, 32))]
    elif sig.args:
        raise StructureException(
            f"Wrong number of args for: {sig.name} (0 args given, expected {len(sig.args)})",
            stmt_expr,
        )

    # Jump to function label.
    jump_to_func = [
        ["add", ["pc"], 6],  # set callback pointer.
        ["goto", f"priv_{sig.method_id}"],
        ["jumpdest"],
    ]

    # Pop return values.
    # Default result when the callee declares no output type.
    returner = [0]
    if sig.output_type:
        output_placeholder, returner, output_size = _call_make_placeholder(
            stmt_expr, context, sig)
        if output_size > 0:
            # (offset, type-or-name) pairs for values that need the dynamic
            # unpacker loop appended below.
            dynamic_offsets = []
            if isinstance(sig.output_type, (BaseType, ListType)):
                # One mstore per 32-byte word of the output.
                pop_return_values = [[
                    "mstore", ["add", output_placeholder, pos], "pass"
                ] for pos in range(0, output_size, 32)]
            elif isinstance(sig.output_type, ByteArrayLike):
                dynamic_offsets = [(0, sig.output_type)]
                pop_return_values = [
                    ["pop", "pass"],
                ]
            elif isinstance(sig.output_type, TupleLike):
                static_offset = 0
                pop_return_values = []
                for name, typ in sig.output_type.tuple_items():
                    if isinstance(typ, ByteArrayLike):
                        # A byte-array member takes one word in the static
                        # part; that word is later loaded as its offset.
                        pop_return_values.append([
                            "mstore",
                            ["add", output_placeholder, static_offset], "pass"
                        ])
                        dynamic_offsets.append(([
                            "mload",
                            ["add", output_placeholder, static_offset]
                        ], name))
                        static_offset += 32
                    else:
                        member_output_size = get_size_of_type(typ) * 32
                        pop_return_values.extend([[
                            "mstore", ["add", output_placeholder, pos], "pass"
                        ] for pos in range(static_offset, static_offset +
                                           member_output_size, 32)])
                        static_offset += member_output_size

            # append dynamic unpacker.
            dyn_idx = 0
            for in_memory_offset, _out_type in dynamic_offsets:
                # Labels are unique per call site and per dynamic value.
                ident = f"{stmt_expr.lineno}_{stmt_expr.col_offset}_arg_{dyn_idx}"
                dyn_idx += 1
                start_label = "dyn_unpack_start_" + ident
                end_label = "dyn_unpack_end_" + ident
                i_placeholder = context.new_internal_variable(
                    typ=BaseType("uint256"))
                begin_pos = ["add", output_placeholder, in_memory_offset]
                # loop until length.
                o = LLLnode.from_list(
                    [
                        "seq_unchecked",
                        ["mstore", begin_pos, "pass"],  # get len
                        ["mstore", i_placeholder, 0],
                        ["label", start_label],
                        [  # break
                            "if",
                            [
                                "ge", ["mload", i_placeholder],
                                ["ceil32", ["mload", begin_pos]]
                            ],
                            ["goto", end_label],
                        ],
                        [  # pop into correct memory slot.
                            "mstore",
                            [
                                "add", ["add", begin_pos, 32],
                                ["mload", i_placeholder]
                            ],
                            "pass",
                        ],
                        # increment i
                        [
                            "mstore", i_placeholder,
                            ["add", 32, ["mload", i_placeholder]]
                        ],
                        ["goto", start_label],
                        ["label", end_label],
                    ],
                    typ=None,
                    annotation="dynamic unpacker",
                    pos=getpos(stmt_expr),
                )
                pop_return_values.append(o)

    # Assemble the pieces in execution order; the returner comes last.
    call_body = list(
        itertools.chain(
            ["seq_unchecked"],
            pre_init,
            push_local_vars,
            push_args,
            jump_to_func,
            pop_return_values,
            pop_local_vars,
            [returner],
        ))
    # If we have no return, we need to pop off
    pop_returner_call_body = ["pop", call_body
                              ] if sig.output_type is None else call_body

    o = LLLnode.from_list(
        pop_returner_call_body,
        typ=sig.output_type,
        location="memory",
        pos=getpos(stmt_expr),
        annotation=f"Internal Call: {method_name}",
        add_gas_estimate=sig.gas,
    )
    # sig.gas is also added to the node's gas attribute directly.
    o.gas += sig.gas
    return o
def parse_external_function(
    code: vy_ast.FunctionDef,
    sig: FunctionSignature,
    context: Context,
    check_nonpayable: bool,
) -> LLLnode:
    """
    Build the LLL for one external function definition.

    The result contains the argument checks ("clampers"), the optional
    non-payable assertion, the nonreentrancy wrappers and the parsed body.
    A function with default parameters gets one selector check per
    generated signature, all jumping to a single shared body.

    :param code: ast of function
    :param sig: the FunctionSignature
    :param context: function context; its variables and memory allocator
        are updated in place
    :param check_nonpayable: if True, assert that the call value is zero
        at the start of every function whose mutability is not "payable"
    :return: full sig compare & function body
    """

    func_type = code._metadata["type"]
    lock_pre, lock_post = get_nonreentrant_lock(func_type, context.global_ctx)

    is_init = sig.name == "__init__"
    base_arg_count = len(sig.base_args)
    allocator = context.memory_allocator

    # Only the constructor copies its base arguments; everything else
    # starts with a no-op.
    if base_arg_count and is_init:
        copier = [
            "codecopy", MemoryPositions.RESERVED_MEMORY, "~codelen", sig.base_copy_size
        ]
        allocator.expand_memory(sig.max_copy_size)
    else:
        copier = ["pass"]
    clampers = [copier]

    if check_nonpayable and sig.mutability != "payable":
        # The contract has payable functions but this is not one of them,
        # so the call value must be zero.
        clampers.append(["assert", ["iszero", "callvalue"]])

    # Assign every argument a location and register it in the context.
    for index, arg in enumerate(sig.args):
        if index < base_arg_count:
            clampers.append(
                make_arg_clamper(
                    arg.pos,
                    allocator.get_next_memory_position(),
                    arg.typ,
                    is_init,
                )
            )
        if isinstance(arg.typ, ByteArrayLike):
            mem_pos = allocator.expand_memory(32 * get_size_of_type(arg.typ))
            record = VariableRecord(arg.name, mem_pos, arg.typ, False)
        elif is_init:
            record = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )
        elif index >= base_arg_count:
            # Default arguments get their own memory region.
            type_size = get_size_of_type(arg.typ) * 32
            default_arg_pos = allocator.expand_memory(type_size)
            record = VariableRecord(
                name=arg.name,
                pos=default_arg_pos,
                typ=arg.typ,
                mutable=False,
            )
        else:
            record = VariableRecord(
                name=arg.name,
                pos=4 + arg.pos,
                typ=arg.typ,
                mutable=False,
                location="calldata",
            )
        context.vars[arg.name] = record

    def routine_body():
        # Lock, checks, parsed statements, unlock, stop.
        statements = [parse_body(c, context) for c in code.body]
        return ["seq"] + lock_pre + clampers + statements + lock_post + [["stop"]]

    # Constructor: no selector check and no trailing stop.
    if is_init:
        return LLLnode.from_list(
            ["seq"] + clampers + [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )

    # Default function: no selector check.
    if sig.is_default_func():
        return LLLnode.from_list(
            ["seq"] + clampers + [parse_body(code.body, context)] + [["stop"]],  # type: ignore
            pos=getpos(code),
        )

    # Normal function without default parameters.
    if not sig.total_default_args > 0:
        sig_compare, _ = get_sig_statements(sig, getpos(code))
        return LLLnode.from_list(
            ["if", sig_compare, routine_body()],
            typ=None,
            pos=getpos(code),
        )

    # Normal function with default parameters.
    function_routine = f"{sig.name}_{sig.method_id}"
    default_sigs = sig_utils.generate_default_arg_sigs(
        code, context.sigs, context.global_ctx
    )
    base_arg_names = {arg.name for arg in sig.base_args}
    sig_chain = ["seq"]

    for default_sig in default_sigs:
        sig_compare, _ = get_sig_statements(default_sig, getpos(code))

        # Store the declared default for every parameter this signature omits.
        set_defaults = []
        for missing_name in get_default_names_to_set(sig, default_sig):
            value = Expr(sig.default_values[missing_name], context).lll_node
            target_var = context.vars[missing_name]
            destination = LLLnode.from_list(
                target_var.pos,
                typ=target_var.typ,
                location="memory",
                pos=getpos(code),
                mutable=target_var.mutable,
            )
            set_defaults.append(
                make_setter(destination, value, "memory", pos=getpos(code))
            )

        copier_arg_count = len(default_sig.args) - base_arg_count
        # Non-base parameters of this signature, in declaration order
        # (the order matters).
        copier_arg_names = [
            x.name for x in default_sig.args if x.name not in base_arg_names
        ]

        # Parameters that this signature supplies through calldata.
        default_copiers = []
        if copier_arg_count > 0:
            # Calldata offset of each parameter; byte arrays take one
            # 32-byte head word here.
            calldata_offset_map = {}
            offset = 4
            for sig_arg in default_sig.args:
                calldata_offset_map[sig_arg.name] = offset
                if isinstance(sig_arg.typ, ByteArrayLike):
                    offset += 32
                else:
                    offset += get_size_of_type(sig_arg.typ) * 32

            for supplied_name in copier_arg_names:
                var = context.vars[supplied_name]
                calldata_offset = calldata_offset_map[supplied_name]

                # Well-formedness check first, then the copy itself.
                default_copiers.append(
                    make_arg_clamper(
                        calldata_offset - 4,
                        var.pos,
                        var.typ,
                    )
                )
                if isinstance(var.typ, ByteArrayLike):
                    copy_from = ["add", 4, ["calldataload", calldata_offset]]
                else:
                    copy_from = calldata_offset
                default_copiers.append(
                    get_external_arg_copier(
                        memory_dest=var.pos,
                        total_size=var.size * 32,
                        offset=copy_from,
                    )
                )

            default_copiers.append(0)  # for over arching seq, POP

        if set_defaults:
            defaults_node = ["seq"] + set_defaults
        else:
            defaults_node = ["pass"]
        if default_copiers:
            copiers_node = ["seq_unchecked"] + default_copiers
        else:
            copiers_node = ["pass"]

        sig_chain.append([
            "if",
            sig_compare,
            ["seq", defaults_node, copiers_node, ["goto", function_routine]],
        ])

    # The shared body is skipped unless one of the checks above jumped to it.
    function_jump_label = f"{sig.name}_{sig.method_id}_skip"
    return LLLnode.from_list(
        [
            "seq",
            sig_chain,
            [
                "seq",
                ["goto", function_jump_label],
                ["label", function_routine],
                routine_body(),
                ["label", function_jump_label],
            ],
        ],
        typ=None,
        pos=getpos(code),
    )
def parse_func(code, _globals, sigs, origcode, _custom_units, _vars=None):
    """Build the LLL node for a single function definition.

    :param code: ast of the function
    :param _globals: contract globals; argument names must not collide
        with them
    :param sigs: known signatures, passed to the signature parser and context
    :param origcode: original source text, stored on the context
    :param _custom_units: custom units, passed to the signature parser and
        context
    :param _vars: initial variable map for the context (a new dict if None)
    :return: LLL node with ``context``, ``total_gas`` and ``func_name`` set
    :raises FunctionDeclarationException: an argument name duplicates a
        global, the default function is ill-formed, or a declared return
        type has no return statement
    """
    variables = {} if _vars is None else _vars
    sig = FunctionSignature.from_definition(code, sigs=sigs, custom_units=_custom_units)

    # Argument names must not shadow contract globals.
    for arg_name in (arg.name for arg in sig.args):
        if arg_name in _globals:
            raise FunctionDeclarationException("Variable name duplicated between function arguments and globals: " + arg_name)

    context = Context(
        vars=variables,
        globals=_globals,
        sigs=sigs,
        return_type=sig.output_type,
        is_constant=sig.const,
        is_payable=sig.payable,
        origcode=origcode,
        custom_units=_custom_units,
    )

    def head_size(typ):
        # Byte arrays count as a single 32-byte word here.
        if isinstance(typ, ByteArrayType):
            return 32
        return get_size_of_type(typ) * 32

    # Reserve room for the fixed-size argument area and pick the copy
    # instruction that fills it.
    copy_size = sum(head_size(arg.typ) for arg in sig.args)
    context.next_mem += copy_size

    is_init = sig.name == '__init__'
    if not len(sig.args):
        copier = 'pass'
    else:
        source = ['codecopy', '~codelen'] if is_init else ['calldatacopy', 4]
        copier = [source[0], MemoryPositions.RESERVED_MEMORY, source[1], copy_size]
    clampers = [copier]

    # Payable / private guards.
    if not sig.payable:
        clampers.append(['assert', ['iszero', 'callvalue']])
    if sig.private:
        clampers.append(['assert', ['eq', 'caller', 'address']])

    # One clamper per argument, plus its variable record.
    for arg in sig.args:
        clampers.append(make_clamper(arg.pos, context.next_mem, arg.typ, is_init))
        if isinstance(arg.typ, ByteArrayType):
            context.vars[arg.name] = VariableRecord(arg.name, context.next_mem, arg.typ, False)
            context.next_mem += 32 * get_size_of_type(arg.typ)
        else:
            fixed_pos = MemoryPositions.RESERVED_MEMORY + arg.pos
            context.vars[arg.name] = VariableRecord(arg.name, fixed_pos, arg.typ, False)

    # The constructor and the default function carry no method-id check.
    if is_init or is_default_func(sig):
        if not is_init:
            if len(sig.args) > 0:
                raise FunctionDeclarationException('Default function may not receive any arguments.', code)
            if sig.private:
                raise FunctionDeclarationException('Default function may only be public.', code)
        o = LLLnode.from_list(['seq'] + clampers + [parse_body(code.body, context)], pos=getpos(code))
    else:
        method_id_node = LLLnode.from_list(sig.method_id, pos=getpos(code), annotation='%s' % sig.name)
        statements = [parse_body(c, context) for c in code.body]
        o = LLLnode.from_list(
            [
                'if',
                ['eq', ['mload', 0], method_id_node],
                ['seq'] + clampers + statements + ['stop'],
            ],
            typ=None,
            pos=getpos(code),
        )

    # A declared return type needs at least one return statement.
    if context.return_type and context.function_return_count == 0:
        raise FunctionDeclarationException(
            "Missing return statement in function '%s' " % sig.name, code
        )

    o.context = context
    o.total_gas = o.gas + calc_mem_gas(o.context.next_mem)
    o.func_name = sig.name
    return o