Exemple #1
0
def generate_default_arg_sigs(code, _contracts, _custom_units):
    """Build one signature for every arity a function with defaults accepts.

    For a definition with ``n`` default arguments, ``n + 1`` signatures are
    produced: the first takes only the non-default arguments and each
    following one additionally takes the next default argument, in
    declaration order. A definition without defaults yields one signature.

    Args:
        code: Function definition node. ``code.args.args`` and
            ``code.args.defaults`` are read; ``code`` itself is not modified
            (each variant is built on a deep copy).
        _contracts: Forwarded to ``FunctionSignature.from_definition`` as
            ``sigs``.
        _custom_units: Forwarded to ``FunctionSignature.from_definition`` as
            ``custom_units``.

    Returns:
        A list of the objects returned by
        ``FunctionSignature.from_definition``, ordered from fewest to most
        arguments.
    """
    total_default_args = len(code.args.defaults)
    if not total_default_args:
        return [FunctionSignature.from_definition(code, sigs=_contracts, custom_units=_custom_units)]

    base_args = code.args.args[:-total_default_args]
    default_args = code.args.args[-total_default_args:]

    sig_fun_defs = []
    # ``supplied`` is how many of the trailing default arguments the caller
    # passes explicitly for this variant (0 .. total_default_args).
    for supplied in range(total_default_args + 1):
        new_code = copy.deepcopy(code)
        # Base args are deep-copied; the default arg nodes are shared with
        # the original definition.
        new_code.args.args = copy.deepcopy(base_args) + default_args[:supplied]
        # NOTE(review): this sets ``default`` (singular), so ``defaults`` on
        # the copy is left untouched. Possibly a typo for ``defaults``; kept
        # as-is because changing it may alter what from_definition sees.
        new_code.args.default = []
        sig_fun_defs.append(
            FunctionSignature.from_definition(new_code, sigs=_contracts, custom_units=_custom_units)
        )

    return sig_fun_defs
Exemple #2
0
def mk_full_signature(global_ctx, sig_formatter=None):
    """Format the signatures of all events and non-internal functions.

    Args:
        global_ctx: Context object; ``_events``, ``_defs``, ``_contracts``,
            ``_structs`` and ``_constants`` are read from it.
        sig_formatter: Callable applied to every signature. When ``None``,
            the module's ``_default_sig_formatter`` is used.

    Returns:
        A list with the formatted event signatures first, followed by the
        formatted function signatures (one per default-argument variant).
    """
    formatter = _default_sig_formatter if sig_formatter is None else sig_formatter

    # Event signatures come first.
    formatted = [
        formatter(EventSignature.from_declaration(event_def, global_ctx))
        for event_def in global_ctx._events
    ]

    # Then every variant of each function that is not internal.
    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_structs=global_ctx._structs,
            constants=global_ctx._constants,
        )
        if func_sig.internal:
            continue
        variants = generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx)
        formatted.extend(formatter(variant) for variant in variants)

    return formatted
Exemple #3
0
def call_lookup_specs(stmt_expr, context):
    """Resolve the signature targeted by a method-call expression.

    Args:
        stmt_expr: Call node; ``func.attr`` gives the method name and
            ``args`` the positional arguments.
        context: Context used to parse the arguments; ``context.sigs`` is
            searched for the signature.

    Returns:
        A ``(method_name, expr_args, sig)`` tuple, where ``expr_args`` holds
        the ``lll_node`` of each parsed argument.
    """
    # Imported locally, as in the original code.
    from vyper.parser.expr import Expr

    method_name = stmt_expr.func.attr

    expr_args = []
    for raw_arg in stmt_expr.args:
        expr_args.append(Expr(raw_arg, context).lll_node)

    sig = FunctionSignature.lookup_sig(
        context.sigs,
        method_name,
        expr_args,
        stmt_expr,
        context,
    )
    return method_name, expr_args, sig
Exemple #4
0
def parse_external_contracts(external_contracts, _contracts, _structs):
    """Turn external contract declarations into per-contract signature maps.

    Args:
        external_contracts: Dict that receives one entry per contract name;
            it is mutated and also returned.
        _contracts: Mapping of contract name to a list of function
            definitions.
        _structs: Forwarded to ``FunctionSignature.from_definition`` as
            ``custom_structs``.

    Returns:
        ``external_contracts``, with ``{function name: signature}`` stored
        under each contract name.

    Raises:
        FunctionDeclarationException: A contract declares the same function
            name more than once.
        StructureException: A function body is not exactly one bare
            ``modifying`` or ``constant`` name.
    """
    for _contractname, _contract_defs in _contracts.items():
        _defnames = [_def.name for _def in _contract_defs]
        if len(set(_defnames)) < len(_contract_defs):
            raise FunctionDeclarationException(
                "Duplicate function name: %s" %
                [name for name in _defnames if _defnames.count(name) > 1][0])

        contract = {}
        for _def in _contract_defs:
            # The body must be a single expression statement naming the
            # call type.
            has_call_type = (
                len(_def.body) == 1
                and isinstance(_def.body[0], ast.Expr)
                and isinstance(_def.body[0].value, ast.Name)
                and _def.body[0].value.id in ('modifying', 'constant')
            )
            if not has_call_type:
                raise StructureException(
                    'constant or modifying call type must be specified', _def)
            constant = _def.body[0].value.id == 'constant'

            # Recognizes already-defined structs
            sig = FunctionSignature.from_definition(
                _def,
                contract_def=True,
                constant=constant,
                custom_structs=_structs,
            )
            contract[sig.name] = sig

        external_contracts[_contractname] = contract

    return external_contracts
Exemple #5
0
def mk_full_signature(code, sig_formatter=None, interface_codes=None):
    """Format the signatures of all events and non-private functions.

    Args:
        code: Passed to ``GlobalContext.get_global_context``.
        sig_formatter: Callable invoked as
            ``sig_formatter(sig, custom_units_descriptions)``. When ``None``,
            the module's ``_default_sig_formatter`` is used.
        interface_codes: Forwarded to ``GlobalContext.get_global_context``.

    Returns:
        A list with the formatted event signatures first, followed by the
        formatted function signatures (one per default-argument variant).
    """
    if sig_formatter is None:
        sig_formatter = _default_sig_formatter

    global_ctx = GlobalContext.get_global_context(
        code,
        interface_codes=interface_codes,
    )
    abi_entries = []

    # Event signatures.
    for event_def in global_ctx._events:
        event_sig = EventSignature.from_declaration(event_def, global_ctx)
        abi_entries.append(
            sig_formatter(event_sig, global_ctx._custom_units_descriptions))

    # Function signatures; private functions are left out.
    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
            custom_structs=global_ctx._structs,
            constants=global_ctx._constants,
        )
        if func_sig.private:
            continue
        variants = generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx)
        for variant in variants:
            abi_entries.append(
                sig_formatter(variant, global_ctx._custom_units_descriptions))

    return abi_entries
def validate_external_function(code: ast.FunctionDef, sig: FunctionSignature,
                               global_ctx: GlobalContext) -> None:
    """Check an external function definition.

    Args:
        code: Definition node, attached to the raised exception.
        sig: Signature of the function being checked.
        global_ctx: Not used by this check.

    Raises:
        FunctionDeclarationException: ``sig`` is the initializer and has at
            least one default argument.
    """
    # Only the initializer is restricted here.
    if not sig.is_initializer():
        return

    if sig.total_default_args > 0:
        raise FunctionDeclarationException(
            "__init__ function may not have default parameters.", code)
Exemple #7
0
def mk_full_signature(code):
    """Return the ABI dicts of all events and non-private functions.

    Args:
        code: Passed to ``get_contracts_and_defs_and_globals``.

    Returns:
        A list of ``to_abi_dict()`` results: events first, then functions.
    """
    _contracts, _events, _defs, _globals, _custom_units = get_contracts_and_defs_and_globals(code)

    abi = [
        EventSignature.from_declaration(event_def, custom_units=_custom_units).to_abi_dict()
        for event_def in _events
    ]

    for func_def in _defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=_contracts,
            custom_units=_custom_units,
        )
        if func_sig.private:
            continue
        abi.append(func_sig.to_abi_dict())

    return abi
Exemple #8
0
def mk_method_identifiers(code):
    """Map each non-private function signature to its hex method id.

    Args:
        code: Source handed to ``parse`` before the global context is built.

    Returns:
        A dict of ``sig`` string to ``hex(method_id)``, with one entry per
        default-argument variant of every non-private function.
    """
    global_ctx = GlobalContext.get_global_context(parse(code))
    identifiers = {}

    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
            constants=global_ctx._constants,
        )
        if func_sig.private:
            continue
        for variant in generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx):
            identifiers[variant.sig] = hex(variant.method_id)

    return identifiers
Exemple #9
0
def mk_full_signature(code):
    """Return the ABI dicts of all events and non-private functions.

    Args:
        code: Passed to ``GlobalContext.get_global_context``.

    Returns:
        A list of ``to_abi_dict()`` results: events first, then one entry per
        default-argument variant of each non-private function.
    """
    global_ctx = GlobalContext.get_global_context(code)

    abi = [
        EventSignature.from_declaration(
            event_def, custom_units=global_ctx._custom_units
        ).to_abi_dict()
        for event_def in global_ctx._events
    ]

    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
        )
        if func_sig.private:
            continue
        variants = generate_default_arg_sigs(
            func_def, global_ctx._contracts, global_ctx._custom_units)
        abi.extend(variant.to_abi_dict() for variant in variants)

    return abi
Exemple #10
0
def mk_single_method_identifier(code, global_ctx):
    """Map the signatures of one function definition to hex method ids.

    Args:
        code: Function definition node.
        global_ctx: Context object; ``_contracts``, ``_structs`` and
            ``_constants`` are read from it.

    Returns:
        A dict of ``sig`` string to ``hex(method_id)`` covering every
        default-argument variant, or an empty dict for a private function.
    """
    func_sig = FunctionSignature.from_definition(
        code,
        sigs=global_ctx._contracts,
        custom_structs=global_ctx._structs,
        constants=global_ctx._constants,
    )
    if func_sig.private:
        return {}

    variants = generate_default_arg_sigs(code, global_ctx._contracts, global_ctx)
    return {variant.sig: hex(variant.method_id) for variant in variants}
Exemple #11
0
def validate_public_function(code: ast.FunctionDef, sig: FunctionSignature,
                             global_ctx: GlobalContext) -> None:
    """Check a public function definition.

    Args:
        code: Definition node, attached to any raised exception.
        sig: Signature of the function being checked.
        global_ctx: Context object; ``_globals`` is consulted for name
            clashes.

    Raises:
        FunctionDeclarationException: ``sig`` is the initializer and has
            default arguments, or an argument name is also present in
            ``global_ctx._globals``.
    """
    # The initializer may not declare default parameters.
    initializer_has_defaults = sig.is_initializer() and sig.total_default_args > 0
    if initializer_has_defaults:
        raise FunctionDeclarationException(
            "__init__ function may not have default parameters.", code)

    # Argument names must not collide with globals.
    for param in sig.args:
        if param.name not in global_ctx._globals:
            continue
        raise FunctionDeclarationException(
            "Variable name duplicated between "
            "function arguments and globals: " + param.name, code)
Exemple #12
0
def mk_full_signature_from_json(abi):
    """Build function signatures from a JSON ABI description.

    Args:
        abi: Iterable of ABI entry dicts. Only entries whose ``"type"`` is
            ``"function"`` are used; their ``"name"``, ``"inputs"`` and
            ``"outputs"`` keys are read, plus the optional ``"constant"``,
            ``"payable"`` and ``"stateMutability"`` keys.

    Returns:
        A list with one ``FunctionSignature.from_definition`` result per
        function entry, in ABI order.
    """
    function_entries = [entry for entry in abi if entry["type"] == "function"]
    sigs = []

    for func in function_entries:
        args = [
            vy_ast.arg(
                arg=a["name"],
                annotation=abi_type_to_ast(a["type"], 1048576),
                lineno=0,
                col_offset=0,
            )
            for a in func["inputs"]
        ]

        # No outputs -> None, one output -> its type, several -> a tuple.
        outputs = func["outputs"]
        if len(outputs) == 1:
            returns = abi_type_to_ast(outputs[0]["type"], 1)
        elif len(outputs) > 1:
            returns = vy_ast.Tuple(
                elements=[abi_type_to_ast(a["type"], 1) for a in outputs])
        else:
            returns = None

        # Handle either constant/payable or stateMutability field
        mutability = func.get("stateMutability")
        decorator_list = [vy_ast.Name(id="external")]
        if func.get("constant") or mutability == "view":
            decorator_list.append(vy_ast.Name(id="view"))
        if func.get("payable") or mutability == "payable":
            decorator_list.append(vy_ast.Name(id="payable"))

        definition = vy_ast.FunctionDef(
            name=func["name"],
            args=vy_ast.arguments(args=args),
            decorator_list=decorator_list,
            returns=returns,
        )
        sigs.append(
            FunctionSignature.from_definition(
                code=definition,
                custom_structs=dict(),
                constants=Constants(),
                is_from_json=True,
            ))

    return sigs
Exemple #13
0
def mk_full_signature(global_ctx, sig_formatter):
    """Format the signatures of all non-internal functions.

    Args:
        global_ctx: Context object; ``_defs``, ``_contracts`` and
            ``_structs`` are read from it.
        sig_formatter: Callable applied to every signature.

    Returns:
        A list of formatted signatures, one per default-argument variant of
        each function that is not internal.
    """
    formatted = []

    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_structs=global_ctx._structs,
        )
        if func_sig.internal:
            continue
        variants = generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx)
        formatted.extend(sig_formatter(variant) for variant in variants)

    return formatted
Exemple #14
0
def mk_full_signature_from_json(abi):
    """Build function signatures from a JSON ABI description.

    Both ABI flavours are accepted: the legacy boolean ``constant`` /
    ``payable`` fields and the ``stateMutability`` field. A missing field is
    treated as "not set" rather than raising ``KeyError``.

    Args:
        abi: Iterable of ABI entry dicts. Only entries whose ``'type'`` is
            ``'function'`` are used; their ``'name'``, ``'inputs'`` and
            ``'outputs'`` keys are required.

    Returns:
        A list with one ``FunctionSignature.from_definition`` result per
        function entry, in ABI order.
    """
    funcs = [func for func in abi if func['type'] == 'function']
    sigs = []

    for func in funcs:
        args = [
            ast.arg(
                arg=a['name'],
                annotation=abi_type_to_ast(a['type']),
                lineno=0,
                col_offset=0
            )
            for a in func['inputs']
        ]

        # No outputs -> None, one output -> its type, several -> a tuple.
        returns = None
        outputs = func['outputs']
        if len(outputs) == 1:
            returns = abi_type_to_ast(outputs[0]['type'])
        elif len(outputs) > 1:
            returns = ast.Tuple(
                elts=[abi_type_to_ast(a['type']) for a in outputs]
            )

        # Handle either constant/payable or stateMutability field. In the
        # legacy ABI format ``constant`` is true for both view and pure
        # functions, so both mutabilities map to the 'constant' decorator.
        mutability = func.get('stateMutability')
        decorator_list = [ast.Name(id='public')]
        if func.get('constant') or mutability in ('view', 'pure'):
            decorator_list.append(ast.Name(id='constant'))
        if func.get('payable') or mutability == 'payable':
            decorator_list.append(ast.Name(id='payable'))

        sig = FunctionSignature.from_definition(
            code=ast.FunctionDef(
                name=func['name'],
                args=ast.arguments(args=args),
                decorator_list=decorator_list,
                returns=returns,
            ),
            custom_units=set(),
            custom_structs=dict(),
            constants=Constants()
        )
        sigs.append(sig)

    return sigs
def mk_method_identifiers(code):
    """Collect the method identifiers of all non-private functions.

    Args:
        code: Source handed to ``parse`` before the global context is built.

    Returns:
        ``dict()`` of the values returned by ``get_method_identifier`` for
        every default-argument variant of each non-private function.
        Presumably each value is a ``(signature, identifier)`` pair — confirm
        against ``get_method_identifier``.
    """
    global_ctx = GlobalContext.get_global_context(parse(code))
    entries = []

    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
        )
        if func_sig.private:
            continue
        variants = generate_default_arg_sigs(
            func_def,
            global_ctx._contracts,
            global_ctx._custom_units,
        )
        for variant in variants:
            entries.append(
                variant.get_method_identifier(
                    global_ctx._contracts,
                    global_ctx._custom_units,
                ))

    return dict(entries)
Exemple #16
0
def parse_external_interfaces(external_interfaces, global_ctx):
    """Fill ``external_interfaces`` with per-interface signature maps.

    Args:
        external_interfaces: Dict that receives one entry per interface name;
            it is mutated and also returned.
        global_ctx: Context object; ``_contracts``, ``_interfaces``,
            ``_structs`` and ``_constants`` are read from it.

    Returns:
        ``external_interfaces``. Entries taken from ``global_ctx._interfaces``
        are written last and therefore replace same-named entries derived
        from ``global_ctx._contracts``.

    Raises:
        FunctionDeclarationException: An interface declares the same function
            name more than once.
        StructureException: A function body is not exactly one bare
            ``pure`` / ``view`` / ``nonpayable`` / ``payable`` name.
    """
    for _interfacename, _interface_defs in global_ctx._contracts.items():
        _defnames = [_def.name for _def in _interface_defs]
        if len(set(_defnames)) < len(_interface_defs):
            raise FunctionDeclarationException(
                "Duplicate function name: "
                f"{[name for name in _defnames if _defnames.count(name) > 1][0]}"
            )

        interface = {}
        for _def in _interface_defs:
            # The body must be a single expression statement naming the
            # state mutability.
            # NOTE: Can't import enums here because of circular import
            has_mutability = (
                len(_def.body) == 1
                and isinstance(_def.body[0], vy_ast.Expr)
                and isinstance(_def.body[0].value, vy_ast.Name)
                and _def.body[0].value.id in ("pure", "view", "nonpayable", "payable")
            )
            if not has_mutability:
                raise StructureException(
                    "state mutability of call type must be specified", _def)
            constant = _def.body[0].value.id in ("view", "pure")

            # Recognizes already-defined structs
            sig = FunctionSignature.from_definition(
                _def,
                interface_def=True,
                constant_override=constant,
                custom_structs=global_ctx._structs,
                constants=global_ctx._constants,
            )
            interface[sig.name] = sig

        external_interfaces[_interfacename] = interface

    # From pre-built interfaces keep only the function signatures.
    for interface_name, members in global_ctx._interfaces.items():
        function_sigs = {}
        for member in members:
            if isinstance(member, FunctionSignature):
                function_sigs[member.name] = member
        external_interfaces[interface_name] = function_sigs

    return external_interfaces
Exemple #17
0
def parse_other_functions(o, otherfuncs, _globals, sigs, external_contracts, origcode, _custom_units, fallback_function, runtime_only):
    """Assemble the runtime section from the regular and fallback functions.

    Args:
        o: List that receives the ``['return', 0, ['lll', sub, 0]]`` wrapper
            when ``runtime_only`` is falsy.
        otherfuncs: Regular function definitions, parsed in order.
        _globals: Forwarded to ``parse_func``.
        sigs: Dict of this contract's signatures; each parsed function's
            signature is stored in it under its name.
        external_contracts: Mapping merged with ``{'self': sigs}`` to form
            the signature table handed to ``parse_func``.
        origcode: Forwarded to ``parse_func``.
        _custom_units: Forwarded to ``parse_func`` and
            ``FunctionSignature.from_definition``.
        fallback_function: Sequence whose first element, if any, is the
            fallback function definition.
        runtime_only: When truthy, return the bare runtime list.

    Returns:
        The runtime list ``sub`` when ``runtime_only`` is truthy, otherwise
        ``o`` with the wrapped runtime appended.
    """
    sub = ['seq', initializer_lll]
    add_gas = initializer_lll.gas

    for _def in otherfuncs:
        func_lll = parse_func(_def, _globals, {'self': sigs, **external_contracts}, origcode, _custom_units)  # noqa E999
        sub.append(func_lll)
        # Each later function is charged 30 more gas than the one before it.
        func_lll.total_gas += add_gas
        add_gas += 30

        sig = FunctionSignature.from_definition(_def, external_contracts, custom_units=_custom_units)
        sig.gas = func_lll.total_gas
        sigs[sig.name] = sig

    # The last entry is the fallback function, or a revert if there is none.
    if fallback_function:
        tail = parse_func(fallback_function[0], _globals, {'self': sigs, **external_contracts}, origcode, _custom_units)
    else:
        tail = LLLnode.from_list(['revert', 0, 0], typ=None, annotation='Default function')
    sub.append(tail)

    if runtime_only:
        return sub

    o.append(['return', 0, ['lll', sub, 0]])
    return o
Exemple #18
0
def _call_lookup_specs(stmt_expr, context):
    """Resolve the signature targeted by a ``self`` method-call expression.

    Args:
        stmt_expr: Call node; ``func.attr`` gives the method name, ``args``
            the positional arguments and ``keywords`` the keyword arguments.
        context: Context used to parse the arguments; ``context.sigs`` is
            searched for the signature.

    Returns:
        A ``(method_name, expr_args, sig)`` tuple, where ``expr_args`` holds
        the ``lll_node`` of each parsed argument.

    Raises:
        TypeMismatch: The call uses keyword arguments.
    """
    # Imported locally, as in the original code.
    from vyper.parser.expr import Expr

    method_name = stmt_expr.func.attr

    if len(stmt_expr.keywords) > 0:
        raise TypeMismatch(
            "Cannot use keyword arguments in calls to functions via 'self'",
            stmt_expr,
        )

    expr_args = []
    for raw_arg in stmt_expr.args:
        expr_args.append(Expr(raw_arg, context).lll_node)

    sig = FunctionSignature.lookup_sig(context.sigs, method_name, expr_args,
                                       stmt_expr, context)
    return method_name, expr_args, sig
Exemple #19
0
def parse_external_contracts(external_contracts, global_ctx):
    """Fill ``external_contracts`` with per-contract signature maps.

    Args:
        external_contracts: Dict that receives one entry per contract or
            interface name; it is mutated and also returned.
        global_ctx: Context object; ``_contracts``, ``_interfaces``,
            ``_structs`` and ``_constants`` are read from it.

    Returns:
        ``external_contracts``. Entries taken from ``global_ctx._interfaces``
        are written last and therefore replace same-named entries derived
        from ``global_ctx._contracts``.

    Raises:
        FunctionDeclarationException: A contract declares the same function
            name more than once.
        StructureException: A function body is not exactly one bare
            ``modifying`` or ``constant`` name.
    """
    for _contractname, _contract_defs in global_ctx._contracts.items():
        _defnames = [_def.name for _def in _contract_defs]
        if len(set(_defnames)) < len(_contract_defs):
            raise FunctionDeclarationException(
                "Duplicate function name: "
                f"{[name for name in _defnames if _defnames.count(name) > 1][0]}"
            )

        contract = {}
        for _def in _contract_defs:
            # The body must be a single expression statement naming the
            # call type.
            has_call_type = (
                len(_def.body) == 1
                and isinstance(_def.body[0], ast.Expr)
                and isinstance(_def.body[0].value, ast.Name)
                and _def.body[0].value.id in ('modifying', 'constant')
            )
            if not has_call_type:
                raise StructureException(
                    'constant or modifying call type must be specified', _def)
            constant = _def.body[0].value.id == 'constant'

            # Recognizes already-defined structs
            sig = FunctionSignature.from_definition(
                _def,
                contract_def=True,
                constant_override=constant,
                custom_structs=global_ctx._structs,
                constants=global_ctx._constants,
            )
            contract[sig.name] = sig

        external_contracts[_contractname] = contract

    # From pre-built interfaces keep only the function signatures.
    for interface_name, members in global_ctx._interfaces.items():
        function_sigs = {}
        for member in members:
            if isinstance(member, FunctionSignature):
                function_sigs[member.name] = member
        external_contracts[interface_name] = function_sigs

    return external_contracts
Exemple #20
0
def mk_method_identifiers(code, interface_codes=None):
    """Map each non-private function signature to its hex method id.

    Args:
        code: Source handed to ``parse_to_ast`` before the global context is
            built.
        interface_codes: Forwarded to ``GlobalContext.get_global_context``.

    Returns:
        A dict of ``sig`` string to ``hex(method_id)``, with one entry per
        default-argument variant of every non-private function.
    """
    # Imported locally, as in the original code.
    from vyper.parser.parser import parse_to_ast

    identifiers = {}
    global_ctx = GlobalContext.get_global_context(
        parse_to_ast(code),
        interface_codes=interface_codes,
    )

    for func_def in global_ctx._defs:
        func_sig = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_units=global_ctx._custom_units,
            constants=global_ctx._constants,
        )
        if func_sig.private:
            continue
        variants = generate_default_arg_sigs(func_def, global_ctx._contracts, global_ctx)
        for variant in variants:
            identifiers[variant.sig] = hex(variant.method_id)

    return identifiers
Exemple #21
0
def parse_tree_to_lll(global_ctx: GlobalContext) -> Tuple[LLLnode, LLLnode]:
    """Assemble the LLL for a whole contract from its global context.

    Duplicate function and event names are rejected first. The definitions
    are then split into initializer, default and regular functions (via
    ``is_initializer`` / ``is_default_func``) and parsed into a ``seq`` list.

    Args:
        global_ctx: Context object; ``_defs``, ``_events``, ``_structs``,
            ``_contracts`` and ``_interfaces`` are read from it.

    Returns:
        A pair of ``LLLnode`` objects built from the full list ``o`` and from
        the ``runtime`` list respectively.

    Raises:
        FunctionDeclarationException: Two functions share a name.
        EventDeclarationException: Two events share a name.
    """
    _names_def = [_def.name for _def in global_ctx._defs]
    # Checks for duplicate function names
    if len(set(_names_def)) < len(_names_def):
        raise FunctionDeclarationException(
            "Duplicate function name: "
            f"{[name for name in _names_def if _names_def.count(name) > 1][0]}"
        )
    _names_events = [_event.name for _event in global_ctx._events]
    # Checks for duplicate event names
    if len(set(_names_events)) < len(_names_events):
        raise EventDeclarationException(f"""Duplicate event name:
            {[name for name in _names_events if _names_events.count(name) > 1][0]}"""
                                        )
    # Initialization function
    initfunc = [_def for _def in global_ctx._defs if is_initializer(_def)]
    # Default function
    defaultfunc = [_def for _def in global_ctx._defs if is_default_func(_def)]
    # Regular functions
    otherfuncs = [
        _def for _def in global_ctx._defs
        if not is_initializer(_def) and not is_default_func(_def)
    ]

    # check if any functions in the contract are payable - if not, we do a single
    # ASSERT CALLVALUE ISZERO at the start of the bytecode rather than at the start
    # of each function
    is_contract_payable = next(
        (True for i in global_ctx._defs if FunctionSignature.from_definition(
            i, custom_structs=global_ctx._structs).mutability == "payable"),
        False,
    )

    # ``sigs`` starts empty and is exposed to the parsed functions under the
    # "self" key; it is also handed to parse_other_functions.
    sigs: dict = {}
    external_interfaces: dict = {}
    # Create the main statement
    o = ["seq"]
    if global_ctx._contracts or global_ctx._interfaces:
        external_interfaces = parse_external_interfaces(
            external_interfaces, global_ctx)
    # If there is an init func...
    if initfunc:
        o.append(init_func_init_lll())
        # Only the first initializer found is parsed.
        o.append(
            parse_function(
                initfunc[0],
                {
                    **{
                        "self": sigs
                    },
                    **external_interfaces
                },
                global_ctx,
                False,
            ))

    # If there are regular functions...
    if otherfuncs or defaultfunc:
        o, runtime = parse_other_functions(
            o,
            otherfuncs,
            sigs,
            external_interfaces,
            global_ctx,
            defaultfunc,
            is_contract_payable,
        )
    else:
        # No regular or default functions: the runtime list is a shallow copy
        # of what has been assembled so far.
        runtime = o.copy()

    if not is_contract_payable:
        # if no functions in the contract are payable, assert that callvalue is
        # zero at the beginning of the bytecode
        runtime.insert(1, ["assert", ["iszero", "callvalue"]])

    return LLLnode.from_list(o, typ=None), LLLnode.from_list(runtime, typ=None)
Exemple #22
0
def parse_func(code, sigs, origcode, global_ctx, _vars=None):
    if _vars is None:
        _vars = {}
    sig = FunctionSignature.from_definition(
        code,
        sigs=sigs,
        custom_units=global_ctx._custom_units,
        custom_structs=global_ctx._structs,
        constants=global_ctx._constants)
    # Get base args for function.
    total_default_args = len(code.args.defaults)
    base_args = sig.args[:
                         -total_default_args] if total_default_args > 0 else sig.args
    default_args = code.args.args[-total_default_args:]
    default_values = dict(
        zip([arg.arg for arg in default_args], code.args.defaults))
    # __init__ function may not have defaults.
    if sig.name == '__init__' and total_default_args > 0:
        raise FunctionDeclarationException(
            "__init__ function may not have default parameters.")
    # Check for duplicate variables with globals
    for arg in sig.args:
        if arg.name in global_ctx._globals:
            raise FunctionDeclarationException(
                "Variable name duplicated between function arguments and globals: "
                + arg.name)

    nonreentrant_pre = [['pass']]
    nonreentrant_post = [['pass']]
    if sig.nonreentrant_key:
        nkey = global_ctx.get_nonrentrant_counter(sig.nonreentrant_key)
        nonreentrant_pre = [[
            'seq', ['assert', ['iszero', ['sload', nkey]]],
            ['sstore', nkey, 1]
        ]]
        nonreentrant_post = [['sstore', nkey, 0]]

    # Create a local (per function) context.
    context = Context(
        vars=_vars,
        global_ctx=global_ctx,
        sigs=sigs,
        return_type=sig.output_type,
        constancy=Constancy.Constant if sig.const else Constancy.Mutable,
        is_payable=sig.payable,
        origcode=origcode,
        is_private=sig.private,
        method_id=sig.method_id)

    # Copy calldata to memory for fixed-size arguments
    max_copy_size = sum([
        32 if isinstance(arg.typ, ByteArrayLike) else
        get_size_of_type(arg.typ) * 32 for arg in sig.args
    ])
    base_copy_size = sum([
        32 if isinstance(arg.typ, ByteArrayLike) else
        get_size_of_type(arg.typ) * 32 for arg in base_args
    ])
    context.next_mem += max_copy_size

    clampers = []

    # Create callback_ptr, this stores a destination in the bytecode for a private
    # function to jump to after a function has executed.
    _post_callback_ptr = "{}_{}_post_callback_ptr".format(
        sig.name, sig.method_id)
    if sig.private:
        context.callback_ptr = context.new_placeholder(typ=BaseType('uint256'))
        clampers.append(
            LLLnode.from_list(
                ['mstore', context.callback_ptr, 'pass'],
                annotation='pop callback pointer',
            ))
        if total_default_args > 0:
            clampers.append(['label', _post_callback_ptr])

    # private functions without return types need to jump back to
    # the calling function, as there is no return statement to handle the
    # jump.
    stop_func = [['stop']]
    if sig.output_type is None and sig.private:
        stop_func = [['jump', ['mload', context.callback_ptr]]]

    if not len(base_args):
        copier = 'pass'
    elif sig.name == '__init__':
        copier = [
            'codecopy', MemoryPositions.RESERVED_MEMORY, '~codelen',
            base_copy_size
        ]
    else:
        copier = get_arg_copier(sig=sig,
                                total_size=base_copy_size,
                                memory_dest=MemoryPositions.RESERVED_MEMORY)
    clampers.append(copier)

    # Add asserts for payable and internal
    # private never gets payable check.
    if not sig.payable and not sig.private:
        clampers.append(['assert', ['iszero', 'callvalue']])

    # Fill variable positions
    for i, arg in enumerate(sig.args):
        if i < len(base_args) and not sig.private:
            clampers.append(
                make_clamper(
                    arg.pos,
                    context.next_mem,
                    arg.typ,
                    sig.name == '__init__',
                ))
        if isinstance(arg.typ, ByteArrayLike):
            context.vars[arg.name] = VariableRecord(arg.name, context.next_mem,
                                                    arg.typ, False)
            context.next_mem += 32 * get_size_of_type(arg.typ)
        else:
            context.vars[arg.name] = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )

    # Private function copiers. No clamping for private functions.
    dyn_variable_names = [
        a.name for a in base_args if isinstance(a.typ, ByteArrayLike)
    ]
    if sig.private and dyn_variable_names:
        i_placeholder = context.new_placeholder(typ=BaseType('uint256'))
        unpackers = []
        for idx, var_name in enumerate(dyn_variable_names):
            var = context.vars[var_name]
            ident = "_load_args_%d_dynarg%d" % (sig.method_id, idx)
            o = make_unpacker(ident=ident,
                              i_placeholder=i_placeholder,
                              begin_pos=var.pos)
            unpackers.append(o)

        if not unpackers:
            unpackers = ['pass']

        clampers.append(
            LLLnode.from_list(
                # [0] to complete full overarching 'seq' statement, see private_label.
                ['seq_unchecked'] + unpackers + [0],
                typ=None,
                annotation='dynamic unpacker',
                pos=getpos(code),
            ))

    # Create "clampers" (input well-formedness checkers)
    # Return function body
    if sig.name == '__init__':
        o = LLLnode.from_list(
            ['seq'] + clampers + [parse_body(code.body, context)],
            pos=getpos(code),
        )
    elif is_default_func(sig):
        if len(sig.args) > 0:
            raise FunctionDeclarationException(
                'Default function may not receive any arguments.', code)
        if sig.private:
            raise FunctionDeclarationException(
                'Default function may only be public.',
                code,
            )
        o = LLLnode.from_list(
            ['seq'] + clampers + [parse_body(code.body, context)],
            pos=getpos(code),
        )
    else:

        if total_default_args > 0:  # Function with default parameters.
            function_routine = "{}_{}".format(sig.name, sig.method_id)
            default_sigs = generate_default_arg_sigs(code, sigs, global_ctx)
            sig_chain = ['seq']

            for default_sig in default_sigs:
                sig_compare, private_label = get_sig_statements(
                    default_sig, getpos(code))

                # Populate unset default variables
                populate_arg_count = len(sig.args) - len(default_sig.args)
                set_defaults = []
                if populate_arg_count > 0:
                    current_sig_arg_names = {x.name for x in default_sig.args}
                    missing_arg_names = [
                        arg.arg for arg in default_args
                        if arg.arg not in current_sig_arg_names
                    ]
                    for arg_name in missing_arg_names:
                        value = Expr(default_values[arg_name],
                                     context).lll_node
                        var = context.vars[arg_name]
                        left = LLLnode.from_list(var.pos,
                                                 typ=var.typ,
                                                 location='memory',
                                                 pos=getpos(code),
                                                 mutable=var.mutable)
                        set_defaults.append(
                            make_setter(left,
                                        value,
                                        'memory',
                                        pos=getpos(code)))

                current_sig_arg_names = {x.name for x in default_sig.args}
                base_arg_names = {arg.name for arg in base_args}
                if sig.private:
                    # Load all variables in default section, if private,
                    # because the stack is a linear pipe.
                    copier_arg_count = len(default_sig.args)
                    copier_arg_names = current_sig_arg_names
                else:
                    copier_arg_count = len(default_sig.args) - len(base_args)
                    copier_arg_names = current_sig_arg_names - base_arg_names
                # Order copier_arg_names, this is very important.
                copier_arg_names = [
                    x.name for x in default_sig.args
                    if x.name in copier_arg_names
                ]

                # Variables to be populated from calldata/stack.
                default_copiers = []
                if copier_arg_count > 0:
                    # Get map of variables in calldata, with thier offsets
                    offset = 4
                    calldata_offset_map = {}
                    for arg in default_sig.args:
                        calldata_offset_map[arg.name] = offset
                        offset += (32 if isinstance(arg.typ, ByteArrayLike)
                                   else get_size_of_type(arg.typ) * 32)
                    # Copy set default parameters from calldata
                    dynamics = []
                    for arg_name in copier_arg_names:
                        var = context.vars[arg_name]
                        calldata_offset = calldata_offset_map[arg_name]
                        if sig.private:
                            _offset = calldata_offset
                            if isinstance(var.typ, ByteArrayLike):
                                _size = 32
                                dynamics.append(var.pos)
                            else:
                                _size = var.size * 32
                            default_copiers.append(
                                get_arg_copier(
                                    sig=sig,
                                    memory_dest=var.pos,
                                    total_size=_size,
                                    offset=_offset,
                                ))
                        else:
                            # Add clampers.
                            default_copiers.append(
                                make_clamper(
                                    calldata_offset - 4,
                                    var.pos,
                                    var.typ,
                                ))
                            # Add copying code.
                            if isinstance(var.typ, ByteArrayLike):
                                _offset = [
                                    'add', 4,
                                    ['calldataload', calldata_offset]
                                ]
                            else:
                                _offset = calldata_offset
                            default_copiers.append(
                                get_arg_copier(
                                    sig=sig,
                                    memory_dest=var.pos,
                                    total_size=var.size * 32,
                                    offset=_offset,
                                ))

                    # Unpack byte array if necessary.
                    if dynamics:
                        i_placeholder = context.new_placeholder(
                            typ=BaseType('uint256'))
                        for idx, var_pos in enumerate(dynamics):
                            ident = 'unpack_default_sig_dyn_%d_arg%d' % (
                                default_sig.method_id, idx)
                            default_copiers.append(
                                make_unpacker(
                                    ident=ident,
                                    i_placeholder=i_placeholder,
                                    begin_pos=var_pos,
                                ))
                    default_copiers.append(0)  # for over arching seq, POP

                sig_chain.append([
                    'if', sig_compare,
                    [
                        'seq', private_label, ['pass'] if not sig.private else
                        LLLnode.from_list([
                            'mstore',
                            context.callback_ptr,
                            'pass',
                        ],
                                          annotation='pop callback pointer',
                                          pos=getpos(code)),
                        ['seq'] + set_defaults if set_defaults else ['pass'],
                        ['seq_unchecked'] +
                        default_copiers if default_copiers else ['pass'],
                        [
                            'goto', _post_callback_ptr
                            if sig.private else function_routine
                        ]
                    ]
                ])

            # With private functions all variable loading occurs in the default
            # function sub routine.
            if sig.private:
                _clampers = [['label', _post_callback_ptr]]
            else:
                _clampers = clampers

            # Function with default parameters.
            o = LLLnode.from_list(
                [
                    'seq',
                    sig_chain,
                    [
                        'if',
                        0,  # can only be jumped into
                        [
                            'seq', ['label', function_routine]
                            if not sig.private else ['pass'],
                            ['seq'] + nonreentrant_pre + _clampers +
                            [parse_body(c, context) for c in code.body] +
                            nonreentrant_post + stop_func
                        ],
                    ],
                ],
                typ=None,
                pos=getpos(code))

        else:
            # Function without default parameters.
            sig_compare, private_label = get_sig_statements(sig, getpos(code))
            o = LLLnode.from_list([
                'if', sig_compare,
                ['seq'] + [private_label] + nonreentrant_pre + clampers +
                [parse_body(c, context)
                 for c in code.body] + nonreentrant_post + stop_func
            ],
                                  typ=None,
                                  pos=getpos(code))

    # Check for at leasts one return statement if necessary.
    if context.return_type and context.function_return_count == 0:
        raise FunctionDeclarationException(
            "Missing return statement in function '%s' " % sig.name, code)

    o.context = context
    o.total_gas = o.gas + calc_mem_gas(o.context.next_mem)
    o.func_name = sig.name
    return o
Exemple #23
0
def parse_func(code, sigs, origcode, global_ctx, _vars=None):
    """
    Parse a function definition and produce the LLL for its method-id
    check, argument copying/clamping and body.

    :param code: ast of the function; its ``args`` and ``body`` are read
    :param sigs: signatures handed to ``FunctionSignature.from_definition``
                 and to the ``Context``
    :param origcode: passed through to the ``Context`` as ``origcode``
    :param global_ctx: global context; its ``_globals`` and ``_custom_units``
                       are read
    :param _vars: initial ``vars`` mapping for the ``Context``; a fresh dict
                  is used when omitted
    :return: LLLnode for the function, with ``context``, ``total_gas`` and
             ``func_name`` set on it
    :raises FunctionDeclarationException: when ``__init__`` has default
        parameters, an argument shares its name with a global, the default
        function takes arguments or is private, or a function with a return
        type contains no return statement
    """
    if _vars is None:
        _vars = {}
    sig = FunctionSignature.from_definition(
        code, sigs=sigs, custom_units=global_ctx._custom_units)
    # Split the arguments into base (always supplied) and default ones.
    total_default_args = len(code.args.defaults)
    if total_default_args > 0:
        base_args = sig.args[:-total_default_args]
        default_args = code.args.args[-total_default_args:]
    else:
        # A `[-0:]` slice would select every argument, so the no-defaults
        # case is handled explicitly.
        base_args = sig.args
        default_args = []
    default_values = dict(
        zip([arg.arg for arg in default_args], code.args.defaults))
    # __init__ function may not have defaults.
    if sig.name == '__init__' and total_default_args > 0:
        raise FunctionDeclarationException(
            "__init__ function may not have default parameters.")
    # Check for duplicate variables with globals
    for arg in sig.args:
        if arg.name in global_ctx._globals:
            raise FunctionDeclarationException(
                "Variable name duplicated between function arguments and globals: "
                + arg.name)

    # Create a context
    context = Context(vars=_vars,
                      globals=global_ctx._globals,
                      sigs=sigs,
                      return_type=sig.output_type,
                      is_constant=sig.const,
                      is_payable=sig.payable,
                      origcode=origcode,
                      custom_units=global_ctx._custom_units)
    # Size of the argument head: byte arrays take one word there, every other
    # type takes its full size.
    max_copy_size = sum(
        32 if isinstance(arg.typ, ByteArrayType) else
        get_size_of_type(arg.typ) * 32 for arg in sig.args)
    base_copy_size = sum(
        32 if isinstance(arg.typ, ByteArrayType) else
        get_size_of_type(arg.typ) * 32 for arg in base_args)
    # Reserve room for all arguments, including the default ones.
    context.next_mem += max_copy_size

    # Only the base arguments are copied up front; default arguments are
    # filled in per signature further down.
    if not base_args:
        copier = 'pass'
    elif sig.name == '__init__':
        copier = [
            'codecopy', MemoryPositions.RESERVED_MEMORY, '~codelen',
            base_copy_size
        ]
    else:
        copier = [
            'calldatacopy', MemoryPositions.RESERVED_MEMORY, 4, base_copy_size
        ]
    clampers = [copier]
    # Add asserts for payable and internal
    if not sig.payable:
        clampers.append(['assert', ['iszero', 'callvalue']])
    if sig.private:
        clampers.append(['assert', ['eq', 'caller', 'address']])

    # Fill variable positions
    for i, arg in enumerate(sig.args):
        if i < len(base_args):
            clampers.append(
                make_clamper(arg.pos, context.next_mem, arg.typ,
                             sig.name == '__init__'))
        if isinstance(arg.typ, ByteArrayType):
            # Byte arrays get their own memory region after the head.
            context.vars[arg.name] = VariableRecord(arg.name, context.next_mem,
                                                    arg.typ, False)
            context.next_mem += 32 * get_size_of_type(arg.typ)
        else:
            context.vars[arg.name] = VariableRecord(
                arg.name, MemoryPositions.RESERVED_MEMORY + arg.pos, arg.typ,
                False)

    # Create "clampers" (input well-formedness checkers)
    # Return function body
    if sig.name == '__init__':
        o = LLLnode.from_list(['seq'] + clampers +
                              [parse_body(code.body, context)],
                              pos=getpos(code))
    elif is_default_func(sig):
        if len(sig.args) > 0:
            raise FunctionDeclarationException(
                'Default function may not receive any arguments.', code)
        if sig.private:
            raise FunctionDeclarationException(
                'Default function may only be public.', code)
        o = LLLnode.from_list(['seq'] + clampers +
                              [parse_body(code.body, context)],
                              pos=getpos(code))
    else:
        # Handle default args if present.
        function_routine = "{}_{}".format(sig.name, sig.method_id)
        if total_default_args > 0:
            default_sigs = generate_default_arg_sigs(code, sigs,
                                                     global_ctx._custom_units)
            sig_chain = ['seq']
            base_arg_names = {arg.name for arg in base_args}

            for default_sig in default_sigs:
                method_id_node = LLLnode.from_list(default_sig.method_id,
                                                   pos=getpos(code),
                                                   annotation='%s' %
                                                   default_sig.sig)

                # Populate unset default variables
                populate_arg_count = len(sig.args) - len(default_sig.args)
                set_defaults = []
                if populate_arg_count > 0:
                    current_sig_arg_names = {x.name for x in default_sig.args}
                    missing_arg_names = [
                        arg.arg for arg in default_args
                        if arg.arg not in current_sig_arg_names
                    ]
                    for arg_name in missing_arg_names:
                        value = Expr(default_values[arg_name],
                                     context).lll_node
                        var = context.vars[arg_name]
                        left = LLLnode.from_list(var.pos,
                                                 typ=var.typ,
                                                 location='memory',
                                                 pos=getpos(code),
                                                 mutable=var.mutable)
                        set_defaults.append(
                            make_setter(left,
                                        value,
                                        'memory',
                                        pos=getpos(code)))
                # Variables to be populated from calldata
                copier_arg_count = len(default_sig.args) - len(base_args)
                default_copiers = []
                if copier_arg_count > 0:
                    # Take the names in signature order. Iterating a set
                    # difference instead would make the order of the emitted
                    # copiers depend on string hashing, so the generated LLL
                    # could differ from one run to the next.
                    copier_arg_names = [
                        x.name for x in default_sig.args
                        if x.name not in base_arg_names
                    ]

                    # Get map of variables in calldata, with their offsets
                    offset = 4
                    calldata_offset_map = {}
                    for arg in default_sig.args:
                        calldata_offset_map[arg.name] = offset
                        offset += 32 if isinstance(
                            arg.typ,
                            ByteArrayType) else get_size_of_type(arg.typ) * 32
                    # Copy set default parameters from calldata
                    for arg_name in copier_arg_names:
                        var = context.vars[arg_name]
                        calldata_offset = calldata_offset_map[arg_name]
                        # Add clampers.
                        default_copiers.append(
                            make_clamper(calldata_offset - 4, var.pos,
                                         var.typ))
                        # Add copying code.
                        if isinstance(var.typ, ByteArrayType):
                            # The head word holds an offset to the data; the
                            # copy starts at that offset plus the 4-byte
                            # method id.
                            default_copiers.append([
                                'calldatacopy', var.pos,
                                ['add', 4, ['calldataload', calldata_offset]],
                                var.size * 32
                            ])
                        else:
                            default_copiers.append([
                                'calldatacopy', var.pos, calldata_offset,
                                var.size * 32
                            ])

                # One dispatch arm per signature: fill in the defaults, copy
                # the supplied optional arguments, then jump to the shared
                # body.
                sig_chain.append([
                    'if', ['eq', ['mload', 0], method_id_node],
                    [
                        'seq',
                        (['seq'] + set_defaults) if set_defaults else ['pass'],
                        (['seq'] + default_copiers)
                        if default_copiers else ['pass'],
                        ['goto', function_routine]
                    ]
                ])

            o = LLLnode.from_list(
                [
                    'seq',
                    sig_chain,
                    [
                        'if',
                        0,  # can only be jumped into
                        [
                            'seq', ['label', function_routine
                                    ], ['seq'] + clampers +
                            [parse_body(c, context)
                             for c in code.body] + ['stop']
                        ]
                    ]
                ],
                typ=None,
                pos=getpos(code))

        else:
            # Function without default parameters.
            method_id_node = LLLnode.from_list(sig.method_id,
                                               pos=getpos(code),
                                               annotation='%s' % sig.sig)
            o = LLLnode.from_list([
                'if', ['eq', ['mload', 0], method_id_node], ['seq'] +
                clampers + [parse_body(c, context)
                            for c in code.body] + ['stop']
            ],
                                  typ=None,
                                  pos=getpos(code))

    # Check for at least one return statement if necessary.
    if context.return_type and context.function_return_count == 0:
        raise FunctionDeclarationException(
            "Missing return statement in function '%s' " % sig.name, code)

    o.context = context
    o.total_gas = o.gas + calc_mem_gas(o.context.next_mem)
    o.func_name = sig.name
    return o
Exemple #24
0
    def call(self):
        """
        Produce the LLL for the call expression held in ``self.expr``.

        The callee is matched against these forms, in order:

        * a bare name found in ``dispatch_table``;
        * ``self.<method>(...)``, a call to one of the contract's own
          functions;
        * ``<name>(<address>).<method>(...)``, an external contract call;
        * ``<x>.<attr>.<method>(...)`` where ``<attr>`` is a key of
          ``context.sigs`` or of ``context.globals``, also an external
          contract call.

        :return: LLLnode for the call
        :raises StructureException: for a bare name that is not in
            ``dispatch_table`` or a callee matching none of the forms above
        :raises ConstancyViolationException: when a constant context calls a
            non-constant function of the contract
        :raises TypeMismatchException: when the called function's output type
            is not a base, byte array or tuple type
        """
        from .parser import (
            external_contract_call,
            pack_arguments,
        )
        from vyper.functions import (
            dispatch_table,
        )

        func = self.expr.func

        if isinstance(func, ast.Name):
            function_name = func.id
            if function_name in dispatch_table:
                return dispatch_table[function_name](self.expr, self.context)
            err_msg = "Not a top-level function: {}".format(function_name)
            # Signature keys are cut at the first '(' to get the bare name.
            own_names = {key.split('(')[0] for key in self.context.sigs['self']}
            if function_name in own_names:
                err_msg += ". Did you mean self.{}?".format(function_name)
            raise StructureException(err_msg, self.expr)

        if not isinstance(func, ast.Attribute):
            # Every remaining form is an attribute access. Reject anything
            # else here instead of failing on `func.value` below.
            raise StructureException("Unsupported operator: %r" % ast.dump(self.expr), self.expr)

        target = func.value

        if isinstance(target, ast.Name) and target.id == "self":
            expr_args = [Expr(arg, self.context).lll_node for arg in self.expr.args]
            method_name = func.attr
            sig = FunctionSignature.lookup_sig(
                self.context.sigs, method_name, expr_args, self.expr, self.context)
            if self.context.is_constant and not sig.const:
                raise ConstancyViolationException(
                    "May not call non-constant function '%s' within a constant function." % (method_name),
                    getpos(self.expr)
                )
            add_gas = sig.gas  # gas of call
            inargs, inargsize = pack_arguments(
                sig, expr_args, self.context, pos=getpos(self.expr))
            output_placeholder = self.context.new_placeholder(typ=sig.output_type)
            if isinstance(sig.output_type, BaseType):
                returner = output_placeholder
            elif isinstance(sig.output_type, ByteArrayType):
                # The byte array result is read one word past the start of
                # the output buffer.
                returner = output_placeholder + 32
            elif isinstance(sig.output_type, TupleType):
                returner = output_placeholder
            else:
                raise TypeMismatchException("Invalid output type: %r" % sig.output_type, self.expr)

            # The contract calls its own address and asserts that the call
            # succeeded.
            o = LLLnode.from_list(
                ['seq',
                    ['assert', ['call', ['gas'], ['address'], 0,
                                inargs, inargsize,
                                output_placeholder, get_size_of_type(sig.output_type) * 32]],
                    returner],
                typ=sig.output_type, location='memory',
                pos=getpos(self.expr), add_gas_estimate=add_gas,
                annotation='Internal Call: %s' % method_name)
            o.gas += sig.gas
            return o

        if isinstance(target, ast.Call):
            contract_name = target.func.id
            contract_address = Expr.parse_value_expr(target.args[0], self.context)
            value, gas = self._get_external_contract_keywords()
            return external_contract_call(
                self.expr, self.context, contract_name, contract_address,
                pos=getpos(self.expr), value=value, gas=gas)

        if isinstance(target, ast.Attribute) and target.attr in self.context.sigs:
            contract_name = target.attr
            var = self.context.globals[target.attr]
            contract_address = unwrap_location(LLLnode.from_list(
                var.pos, typ=var.typ, location='storage',
                pos=getpos(self.expr), annotation='self.' + target.attr))
            value, gas = self._get_external_contract_keywords()
            return external_contract_call(
                self.expr, self.context, contract_name, contract_address,
                pos=getpos(self.expr), value=value, gas=gas)

        if isinstance(target, ast.Attribute) and target.attr in self.context.globals:
            # The contract name is taken from the unit of the global's type.
            contract_name = self.context.globals[target.attr].typ.unit
            var = self.context.globals[target.attr]
            contract_address = unwrap_location(LLLnode.from_list(
                var.pos, typ=var.typ, location='storage',
                pos=getpos(self.expr), annotation='self.' + target.attr))
            value, gas = self._get_external_contract_keywords()
            return external_contract_call(
                self.expr, self.context, contract_name, contract_address,
                pos=getpos(self.expr), value=value, gas=gas)

        raise StructureException("Unsupported operator: %r" % ast.dump(self.expr), self.expr)
Exemple #25
0
def parse_public_function(code: ast.FunctionDef, sig: FunctionSignature,
                          context: Context) -> LLLnode:
    """
    Parse a public function (FuncDef), and produce full function body.

    :param code: ast of function
    :param sig: the FunctionSignature
    :param context: the Context; argument variables are registered in its
                    ``vars`` and memory is reserved through its
                    ``memory_allocator``
    :return: full sig compare & function body
    :raises FunctionDeclarationException: when the default function takes
        arguments
    """

    validate_public_function(code, sig, context.global_ctx)

    # Get nonreentrant lock
    nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(
        sig, context.global_ctx)

    clampers = []

    # Generate copiers. Only `__init__` with base arguments copies them (from
    # the code) into memory; every other case needs no up-front copy.
    copier: List[Any] = ['pass']
    if sig.base_args and sig.name == '__init__':
        copier = [
            'codecopy', MemoryPositions.RESERVED_MEMORY, '~codelen',
            sig.base_copy_size
        ]
        context.memory_allocator.increase_memory(sig.max_copy_size)
    clampers.append(copier)

    # Add asserts for payable and internal
    if not sig.payable:
        clampers.append(['assert', ['iszero', 'callvalue']])

    # Fill variable positions
    default_args_start_pos = len(sig.base_args)
    for i, arg in enumerate(sig.args):
        if i < len(sig.base_args):
            clampers.append(
                make_arg_clamper(
                    arg.pos,
                    context.memory_allocator.get_next_memory_position(),
                    arg.typ,
                    sig.name == '__init__',
                ))
        if isinstance(arg.typ, ByteArrayLike):
            mem_pos, _ = context.memory_allocator.increase_memory(
                32 * get_size_of_type(arg.typ))
            context.vars[arg.name] = VariableRecord(arg.name, mem_pos, arg.typ,
                                                    False)
        else:
            if sig.name == '__init__':
                context.vars[arg.name] = VariableRecord(
                    arg.name,
                    MemoryPositions.RESERVED_MEMORY + arg.pos,
                    arg.typ,
                    False,
                )
            elif i >= default_args_start_pos:
                # Default args need to be allocated in memory. Reserve the
                # full size of the type, not a single word: the copier below
                # writes `var.size * 32` bytes to this position.
                # NOTE(review): assumes VariableRecord.size matches
                # get_size_of_type(typ) - confirm.
                default_arg_pos, _ = context.memory_allocator.increase_memory(
                    32 * get_size_of_type(arg.typ))
                context.vars[arg.name] = VariableRecord(
                    name=arg.name,
                    pos=default_arg_pos,
                    typ=arg.typ,
                    mutable=False,
                )
            else:
                # Base args are read in place from calldata, after the
                # 4-byte method id.
                context.vars[arg.name] = VariableRecord(name=arg.name,
                                                        pos=4 + arg.pos,
                                                        typ=arg.typ,
                                                        mutable=False,
                                                        location='calldata')

    # Create "clampers" (input well-formedness checkers)
    # Return function body
    if sig.name == '__init__':
        o = LLLnode.from_list(
            ['seq'] + clampers +
            [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )
    # Is default function.
    elif sig.is_default_func():
        if len(sig.args) > 0:
            raise FunctionDeclarationException(
                'Default function may not receive any arguments.', code)
        o = LLLnode.from_list(
            ['seq'] + clampers +
            [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )
    # Is a normal function.
    else:
        # Function with default parameters.
        if sig.total_default_args > 0:
            function_routine = f"{sig.name}_{sig.method_id}"
            default_sigs = sig_utils.generate_default_arg_sigs(
                code, context.sigs, context.global_ctx)
            sig_chain: List[Any] = ['seq']
            base_arg_names = {arg.name for arg in sig.base_args}

            for default_sig in default_sigs:
                sig_compare, _ = get_sig_statements(default_sig, getpos(code))

                # Populate unset default variables
                set_defaults = []
                for arg_name in get_default_names_to_set(sig, default_sig):
                    value = Expr(sig.default_values[arg_name],
                                 context).lll_node
                    var = context.vars[arg_name]
                    left = LLLnode.from_list(var.pos,
                                             typ=var.typ,
                                             location='memory',
                                             pos=getpos(code),
                                             mutable=var.mutable)
                    set_defaults.append(
                        make_setter(left, value, 'memory', pos=getpos(code)))

                copier_arg_count = len(default_sig.args) - len(sig.base_args)

                # The optional arguments this signature supplies, kept in
                # signature order. The order is very important.
                copier_arg_names = [
                    x.name for x in default_sig.args
                    if x.name not in base_arg_names
                ]

                # Variables to be populated from calldata/stack.
                default_copiers: List[Any] = []
                if copier_arg_count > 0:
                    # Get map of variables in calldata, with their offsets
                    offset = 4
                    calldata_offset_map = {}
                    for arg in default_sig.args:
                        calldata_offset_map[arg.name] = offset
                        offset += (32 if isinstance(arg.typ, ByteArrayLike)
                                   else get_size_of_type(arg.typ) * 32)

                    # Copy default parameters from calldata.
                    for arg_name in copier_arg_names:
                        var = context.vars[arg_name]
                        calldata_offset = calldata_offset_map[arg_name]

                        # Add clampers.
                        default_copiers.append(
                            make_arg_clamper(
                                calldata_offset - 4,
                                var.pos,
                                var.typ,
                            ))
                        # Add copying code.
                        _offset: Union[int, List[Any]] = calldata_offset
                        if isinstance(var.typ, ByteArrayLike):
                            # The head word holds an offset to the data; the
                            # copy starts at that offset plus the 4-byte
                            # method id.
                            _offset = [
                                'add', 4, ['calldataload', calldata_offset]
                            ]
                        default_copiers.append(
                            get_public_arg_copier(
                                memory_dest=var.pos,
                                total_size=var.size * 32,
                                offset=_offset,
                            ))

                    default_copiers.append(0)  # for over arching seq, POP

                # One dispatch arm per signature: fill in the defaults, copy
                # the supplied optional arguments, then jump to the shared
                # body.
                sig_chain.append([
                    'if', sig_compare,
                    [
                        'seq',
                        (['seq'] + set_defaults) if set_defaults else ['pass'],
                        (['seq_unchecked'] + default_copiers)
                        if default_copiers else ['pass'],
                        ['goto', function_routine]
                    ]
                ])

            # Function with default parameters.
            o = LLLnode.from_list(
                [
                    'seq',
                    sig_chain,
                    [
                        'if',
                        0,  # can only be jumped into
                        [
                            'seq', ['label', function_routine
                                    ], ['seq'] + nonreentrant_pre + clampers +
                            [parse_body(c, context) for c in code.body] +
                            nonreentrant_post + [['stop']]
                        ],
                    ],
                ],
                typ=None,
                pos=getpos(code))

        else:
            # Function without default parameters.
            sig_compare, _ = get_sig_statements(sig, getpos(code))
            o = LLLnode.from_list([
                'if', sig_compare, ['seq'] + nonreentrant_pre + clampers +
                [parse_body(c, context)
                 for c in code.body] + nonreentrant_post + [['stop']]
            ],
                                  typ=None,
                                  pos=getpos(code))
    return o
Exemple #26
0
    def call(self):
        """
        Produce the LLL for the call statement held in ``self.stmt``.

        The callee is matched against these forms, in order:

        * a bare name found in ``stmt_dispatch_table``;
        * ``self.<method>(...)``, a call to one of the contract's own
          functions (any result is discarded);
        * ``<name>(<address>).<method>(...)``, an external contract call;
        * ``<x>.<attr>.<method>(...)`` where ``<attr>`` is a key of
          ``context.sigs`` or of ``context.globals``, also an external
          contract call;
        * ``log.<event>(...)``, which emits a declared event.

        :return: LLLnode for the statement
        :raises StructureException: for a bare name that is not in
            ``stmt_dispatch_table`` or a callee matching none of the forms
            above
        :raises ConstancyViolationException: when a constant context calls a
            non-constant function of the contract
        :raises EventDeclarationException: when the event is not in
            ``context.sigs['self']`` or is given the wrong number of
            arguments
        """
        from .parser import (
            pack_arguments,
            pack_logging_data,
            pack_logging_topics,
            external_contract_call,
        )

        func = self.stmt.func

        if isinstance(func, ast.Name):
            if func.id in stmt_dispatch_table:
                return stmt_dispatch_table[func.id](self.stmt, self.context)
            if func.id in dispatch_table:
                # The name is in dispatch_table but not in
                # stmt_dispatch_table, so it is rejected as a bare statement.
                raise StructureException("Function {} can not be called without being used.".format(func.id), self.stmt)
            raise StructureException("Unknown function: '{}'.".format(func.id), self.stmt)

        if not isinstance(func, ast.Attribute):
            # Every remaining form is an attribute access. Reject anything
            # else here instead of failing on `func.value` below.
            raise StructureException("Unsupported operator: %r" % ast.dump(self.stmt), self.stmt)

        target = func.value

        if isinstance(target, ast.Name) and target.id == "self":
            method_name = func.attr
            expr_args = [Expr(arg, self.context).lll_node for arg in self.stmt.args]
            sig = FunctionSignature.lookup_sig(
                self.context.sigs, method_name, expr_args, self.stmt, self.context)
            if self.context.is_constant and not sig.const:
                # The position is passed so the error points at the
                # offending statement.
                raise ConstancyViolationException(
                    "May not call non-constant function '%s' within a constant function." % (sig.sig),
                    getpos(self.stmt)
                )
            add_gas = self.context.sigs['self'][sig.sig].gas
            inargs, inargsize = pack_arguments(
                sig, expr_args, self.context, pos=getpos(self.stmt))
            # The contract calls its own address; the output buffer is
            # empty (0, 0).
            return LLLnode.from_list(
                ['assert', ['call', ['gas'], ['address'], 0, inargs, inargsize, 0, 0]],
                typ=None, pos=getpos(self.stmt), add_gas_estimate=add_gas,
                annotation='Internal Call: %s' % sig.sig)

        if isinstance(target, ast.Call):
            contract_name = target.func.id
            contract_address = Expr.parse_value_expr(target.args[0], self.context)
            return external_contract_call(
                self.stmt, self.context, contract_name, contract_address,
                pos=getpos(self.stmt))

        if isinstance(target, ast.Attribute) and target.attr in self.context.sigs:
            contract_name = target.attr
            var = self.context.globals[target.attr]
            contract_address = unwrap_location(LLLnode.from_list(
                var.pos, typ=var.typ, location='storage',
                pos=getpos(self.stmt), annotation='self.' + target.attr))
            return external_contract_call(
                self.stmt, self.context, contract_name, contract_address,
                pos=getpos(self.stmt))

        if isinstance(target, ast.Attribute) and target.attr in self.context.globals:
            # The contract name is taken from the unit of the global's type.
            contract_name = self.context.globals[target.attr].typ.unit
            var = self.context.globals[target.attr]
            contract_address = unwrap_location(LLLnode.from_list(
                var.pos, typ=var.typ, location='storage',
                pos=getpos(self.stmt), annotation='self.' + target.attr))
            return external_contract_call(
                self.stmt, self.context, contract_name, contract_address,
                pos=getpos(self.stmt))

        # `target` is checked to be a plain name before `.id` is read, so
        # other node kinds fall through to the error below.
        if isinstance(target, ast.Name) and target.id == 'log':
            if func.attr not in self.context.sigs['self']:
                raise EventDeclarationException("Event not declared yet: %s" % func.attr)
            event = self.context.sigs['self'][func.attr]
            if len(event.indexed_list) != len(self.stmt.args):
                raise EventDeclarationException("%s received %s arguments but expected %s" % (event.name, len(self.stmt.args), len(event.indexed_list)))
            # Split the arguments into topics (indexed) and data (the rest).
            expected_topics, topics = [], []
            expected_data, data = [], []
            for idx, is_indexed in enumerate(event.indexed_list):
                if is_indexed:
                    expected_topics.append(event.args[idx])
                    topics.append(self.stmt.args[idx])
                else:
                    expected_data.append(event.args[idx])
                    data.append(self.stmt.args[idx])
            topics = pack_logging_topics(
                event.event_id, topics, expected_topics, self.context,
                pos=getpos(self.stmt))
            inargs, inargsize, inargsize_node, inarg_start = pack_logging_data(
                expected_data, data, self.context, pos=getpos(self.stmt))

            # Use the constant size unless a node holding the size was
            # returned, in which case the size is loaded from it.
            if inargsize_node is None:
                sz = inargsize
            else:
                sz = ['mload', inargsize_node]

            return LLLnode.from_list(
                ['seq', inargs,
                 LLLnode.from_list(["log" + str(len(topics)), inarg_start, sz] + topics,
                                   add_gas_estimate=inargsize * 10)],
                typ=None, pos=getpos(self.stmt))

        raise StructureException("Unsupported operator: %r" % ast.dump(self.stmt), self.stmt)
Exemple #27
0
def make_call(stmt_expr, context):
    """
    Build the LLL node for an internal function call made through ``self``.

    The generated sequence saves the caller's memory variables, pushes the
    call arguments, jumps to the callee's ``priv_<method_id>`` label and,
    once control comes back, pops the return values and restores the saved
    variables.

    :param stmt_expr: AST node of the call expression; ``stmt_expr.func.attr``
                      is the name of the called method
    :param context: current function context (signatures, variables and
                    internal-variable allocation)
    :return: LLLnode typed as ``sig.output_type`` with ``location="memory"``
    :raises StateAccessViolation: if the context is constant and the callee's
                                  mutability is neither "view" nor "pure"
    :raises StructureException: if the callee is not internal, or if no
                                arguments are given although the signature
                                declares some
    """
    # ** Internal Call **
    # Steps:
    # (x) push current local variables
    # (x) push arguments
    # (x) push jumpdest (callback ptr)
    # (x) jump to label
    # (x) pop return values
    # (x) pop local variables

    pop_local_vars = []
    push_local_vars = []
    pop_return_values = []
    push_args = []
    method_name = stmt_expr.func.attr

    # Imported at call time rather than at module level -- presumably to
    # avoid a circular import; TODO confirm.
    from vyper.parser.expr import parse_sequence

    # Evaluate the argument expressions, then resolve the callee signature
    # from the method name and the parsed arguments.
    pre_init, expr_args = parse_sequence(stmt_expr, stmt_expr.args, context)
    sig = FunctionSignature.lookup_sig(
        context.sigs,
        method_name,
        expr_args,
        stmt_expr,
        context,
    )

    # A constant context may only call "view" or "pure" functions.
    if context.is_constant() and sig.mutability not in ("view", "pure"):
        raise StateAccessViolation(
            f"May not call state modifying function "
            f"'{method_name}' within {context.pp_constancy()}.",
            getpos(stmt_expr),
        )

    if not sig.internal:
        raise StructureException("Cannot call external functions via 'self'",
                                 stmt_expr)

    # Push local variables.
    # Collect (position, size) for every context variable located in memory.
    var_slots = [(v.pos, v.size) for name, v in context.vars.items()
                 if v.location == "memory"]
    if var_slots:
        # Order by memory position so the first/last entries bound the range.
        var_slots.sort(key=lambda x: x[0])

        if len(var_slots) > 10:
            # if memory is large enough, push and pop it via iteration
            # mem_from: position of the lowest variable;
            # mem_to: position of the highest variable plus its size * 32.
            mem_from, mem_to = var_slots[0][
                0], var_slots[-1][0] + var_slots[-1][1] * 32
            i_placeholder = context.new_internal_variable(BaseType("uint256"))
            # Labels are suffixed with the line/column of this call expression.
            local_save_ident = f"_{stmt_expr.lineno}_{stmt_expr.col_offset}"
            push_loop_label = "save_locals_start" + local_save_ident
            pop_loop_label = "restore_locals_start" + local_save_ident
            # Save loop: the cursor in i_placeholder walks upwards from
            # mem_from in steps of 32 while it is below mem_to, loading the
            # word at the cursor on each pass.
            push_local_vars = [
                ["mstore", i_placeholder, mem_from],
                ["label", push_loop_label],
                ["mload", ["mload", i_placeholder]],
                [
                    "mstore", i_placeholder,
                    ["add", ["mload", i_placeholder], 32]
                ],
                [
                    "if", ["lt", ["mload", i_placeholder], mem_to],
                    ["goto", push_loop_label]
                ],
            ]
            # Restore loop: walks the same range downwards, starting at
            # mem_to - 32 and continuing while the cursor is >= mem_from,
            # i.e. in the reverse order of the save loop.
            pop_local_vars = [
                ["mstore", i_placeholder, mem_to - 32],
                ["label", pop_loop_label],
                ["mstore", ["mload", i_placeholder], "pass"],
                [
                    "mstore", i_placeholder,
                    ["sub", ["mload", i_placeholder], 32]
                ],
                [
                    "if", ["ge", ["mload", i_placeholder], mem_from],
                    ["goto", pop_loop_label]
                ],
            ]
        else:
            # for smaller memory, hardcode the mload/mstore locations
            push_mem_slots = []
            for pos, size in var_slots:
                push_mem_slots.extend([pos + i * 32 for i in range(size)])

            # Restore in the reverse order of the saves.
            push_local_vars = [["mload", pos] for pos in push_mem_slots]
            pop_local_vars = [["mstore", pos, "pass"]
                              for pos in push_mem_slots[::-1]]

    # Push Arguments
    if expr_args:
        inargs, inargsize, arg_pos = pack_arguments(sig,
                                                    expr_args,
                                                    context,
                                                    stmt_expr,
                                                    is_external_call=False)
        push_args += [
            inargs
        ]  # copy arguments first, to not mess up the push/pop sequencing.

        # static_pos marks the end of the static section, which starts at
        # arg_pos and spans 32 bytes per static slot of every argument.
        static_arg_size = 32 * sum(
            [get_static_size_of_type(arg.typ) for arg in expr_args])
        static_pos = int(arg_pos + static_arg_size)
        needs_dyn_section = any(
            [has_dynamic_data(arg.typ) for arg in expr_args])

        if needs_dyn_section:
            ident = f"push_args_{sig.method_id}_{stmt_expr.lineno}_{stmt_expr.col_offset}"
            start_label = ident + "_start"
            end_label = ident + "_end"
            i_placeholder = context.new_internal_variable(BaseType("uint256"))

            # Calculate copy start position.
            # Given | static | dynamic | section in memory,
            # copy backwards so the values are in order on the stack.
            # We calculate i, the end of the whole encoded part
            # (i.e. the starting index for copy)
            # by taking ceil32(len<arg>) + offset<arg> + arg_pos
            # for the last dynamic argument, where arg_pos is the start of
            # the whole argument section.
            # NOTE(review): last_idx is only assigned for ByteArrayLike
            # arguments; if has_dynamic_data() can be true for an argument
            # list without one, last_idx is unbound below -- confirm.
            idx = 0
            for arg in expr_args:
                if isinstance(arg.typ, ByteArrayLike):
                    last_idx = idx
                idx += get_static_size_of_type(arg.typ)
            push_args += [[
                "with",
                "offset",
                ["mload", arg_pos + last_idx * 32],
                [
                    "with",
                    "len_pos",
                    ["add", arg_pos, "offset"],
                    [
                        "with",
                        "len_value",
                        ["mload", "len_pos"],
                        [
                            "mstore", i_placeholder,
                            ["add", "len_pos", ["ceil32", "len_value"]]
                        ],
                    ],
                ],
            ]]
            # loop from end of dynamic section to start of dynamic section,
            # pushing each element onto the stack.
            push_args += [
                ["label", start_label],
                [
                    "if", ["lt", ["mload", i_placeholder], static_pos],
                    ["goto", end_label]
                ],
                ["mload", ["mload", i_placeholder]],
                [
                    "mstore", i_placeholder,
                    ["sub", ["mload", i_placeholder], 32]
                ],  # decrease i
                ["goto", start_label],
                ["label", end_label],
            ]

        # push static section
        # (highest address first, one word per 32-byte step down to arg_pos)
        push_args += [["mload", pos]
                      for pos in reversed(range(arg_pos, static_pos, 32))]
    elif sig.args:
        # No arguments were supplied although the signature declares some.
        raise StructureException(
            f"Wrong number of args for: {sig.name} (0 args given, expected {len(sig.args)})",
            stmt_expr,
        )

    # Jump to function label.
    jump_to_func = [
        ["add", ["pc"], 6],  # set callback pointer.
        ["goto", f"priv_{sig.method_id}"],
        ["jumpdest"],
    ]

    # Pop return values.
    # Without an output type the sequence ends in 0, which is popped again
    # when the call body is wrapped below.
    returner = [0]
    if sig.output_type:
        output_placeholder, returner, output_size = _call_make_placeholder(
            stmt_expr, context, sig)
        if output_size > 0:
            dynamic_offsets = []
            if isinstance(sig.output_type, (BaseType, ListType)):
                # One store per 32-byte word of the output.
                pop_return_values = [[
                    "mstore", ["add", output_placeholder, pos], "pass"
                ] for pos in range(0, output_size, 32)]
            elif isinstance(sig.output_type, ByteArrayLike):
                # A single byte array is unpacked by the dynamic unpacker
                # appended below, at offset 0 of the placeholder.
                dynamic_offsets = [(0, sig.output_type)]
                pop_return_values = [
                    ["pop", "pass"],
                ]
            elif isinstance(sig.output_type, TupleLike):
                static_offset = 0
                pop_return_values = []
                for name, typ in sig.output_type.tuple_items():
                    if isinstance(typ, ByteArrayLike):
                        # A byte-array member takes one static word; the
                        # value stored there is later loaded as the member's
                        # offset into the placeholder.
                        pop_return_values.append([
                            "mstore",
                            ["add", output_placeholder, static_offset], "pass"
                        ])
                        # The second tuple element is the member name here
                        # (not its type); it is unused by the loop below.
                        dynamic_offsets.append(([
                            "mload",
                            ["add", output_placeholder, static_offset]
                        ], name))
                        static_offset += 32
                    else:
                        member_output_size = get_size_of_type(typ) * 32
                        pop_return_values.extend([[
                            "mstore", ["add", output_placeholder, pos], "pass"
                        ] for pos in range(static_offset, static_offset +
                                           member_output_size, 32)])
                        static_offset += member_output_size

            # append dynamic unpacker.
            dyn_idx = 0
            for in_memory_offset, _out_type in dynamic_offsets:
                ident = f"{stmt_expr.lineno}_{stmt_expr.col_offset}_arg_{dyn_idx}"
                dyn_idx += 1
                start_label = "dyn_unpack_start_" + ident
                end_label = "dyn_unpack_end_" + ident
                i_placeholder = context.new_internal_variable(
                    typ=BaseType("uint256"))
                begin_pos = ["add", output_placeholder, in_memory_offset]
                # loop until length.
                o = LLLnode.from_list(
                    [
                        "seq_unchecked",
                        ["mstore", begin_pos, "pass"],  # get len
                        ["mstore", i_placeholder, 0],
                        ["label", start_label],
                        [  # break
                            "if",
                            [
                                "ge", ["mload", i_placeholder],
                                ["ceil32", ["mload", begin_pos]]
                            ],
                            ["goto", end_label],
                        ],
                        [  # pop into correct memory slot.
                            "mstore",
                            [
                                "add", ["add", begin_pos, 32],
                                ["mload", i_placeholder]
                            ],
                            "pass",
                        ],
                        # increment i
                        [
                            "mstore", i_placeholder,
                            ["add", 32, ["mload", i_placeholder]]
                        ],
                        ["goto", start_label],
                        ["label", end_label],
                    ],
                    typ=None,
                    annotation="dynamic unpacker",
                    pos=getpos(stmt_expr),
                )
                pop_return_values.append(o)

    # Assemble the call. The order matters: return values are popped before
    # the saved local variables are restored.
    call_body = list(
        itertools.chain(
            ["seq_unchecked"],
            pre_init,
            push_local_vars,
            push_args,
            jump_to_func,
            pop_return_values,
            pop_local_vars,
            [returner],
        ))
    # If we have no return, we need to pop off
    pop_returner_call_body = ["pop", call_body
                              ] if sig.output_type is None else call_body

    o = LLLnode.from_list(
        pop_returner_call_body,
        typ=sig.output_type,
        location="memory",
        pos=getpos(stmt_expr),
        annotation=f"Internal Call: {method_name}",
        add_gas_estimate=sig.gas,
    )
    # NOTE(review): sig.gas is also passed as add_gas_estimate above; check
    # whether adding it again here counts the callee's gas twice.
    o.gas += sig.gas
    return o
def parse_external_function(
    code: vy_ast.FunctionDef,
    sig: FunctionSignature,
    context: Context,
    check_nonpayable: bool,
) -> LLLnode:
    """
    Produce the complete LLL for an external function definition.

    Three shapes are generated, depending on the signature: the constructor
    (``__init__``), the default function, and a regular function. A regular
    function with default arguments gets one selector check per signature
    returned by ``generate_default_arg_sigs``, each of which fills in the
    unset defaults, copies the supplied ones from calldata and jumps to a
    shared routine holding the function body.

    :param code: ast of function
    :param sig: the FunctionSignature
    :param context: function context; its ``vars`` and memory allocator are
                    updated while the arguments are laid out
    :param check_nonpayable: if True, include a check that `msg.value == 0`
                             at the beginning of the function
    :return: full sig compare & function body
    """
    func_type = code._metadata["type"]
    is_init = sig.name == "__init__"

    # Statements wrapped around the body when the function is nonreentrant.
    nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(
        func_type, context.global_ctx)

    # Only a constructor with base arguments needs a real copier (its
    # arguments are copied from code); everything else gets a no-op.
    arg_copier: List[Any] = ["pass"]
    if len(sig.base_args) and is_init:
        arg_copier = [
            "codecopy",
            MemoryPositions.RESERVED_MEMORY,
            "~codelen",
            sig.base_copy_size,
        ]
        context.memory_allocator.expand_memory(sig.max_copy_size)
    prelude: List[Any] = [arg_copier]

    # If the contract contains payable functions, but this is not one of
    # them, assert that the value of the call is zero.
    if check_nonpayable and sig.mutability != "payable":
        prelude.append(["assert", ["iszero", "callvalue"]])

    # Lay out every argument: clamp the base arguments and register a
    # VariableRecord for each argument in the context.
    base_arg_count = len(sig.base_args)
    for index, arg in enumerate(sig.args):
        if index < base_arg_count:
            next_free = context.memory_allocator.get_next_memory_position()
            prelude.append(
                make_arg_clamper(arg.pos, next_free, arg.typ, is_init))

        if isinstance(arg.typ, ByteArrayLike):
            # Byte arrays always get their own memory region.
            mem_pos = context.memory_allocator.expand_memory(
                32 * get_size_of_type(arg.typ))
            record = VariableRecord(arg.name, mem_pos, arg.typ, False)
        elif is_init:
            # Constructor arguments sit in the reserved memory region.
            record = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )
        elif index >= base_arg_count:
            # Default args need to be allocated in memory.
            default_arg_pos = context.memory_allocator.expand_memory(
                get_size_of_type(arg.typ) * 32)
            record = VariableRecord(
                name=arg.name,
                pos=default_arg_pos,
                typ=arg.typ,
                mutable=False,
            )
        else:
            # Remaining base arguments are read from calldata, after the
            # 4-byte method id.
            record = VariableRecord(
                name=arg.name,
                pos=4 + arg.pos,
                typ=arg.typ,
                mutable=False,
                location="calldata",
            )
        context.vars[arg.name] = record

    # Constructor: clampers followed by the body.
    if is_init:
        return LLLnode.from_list(
            ["seq"] + prelude +
            [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )

    # Default function: as above, terminated by a stop.
    if sig.is_default_func():
        return LLLnode.from_list(
            ["seq"] + prelude + [parse_body(code.body, context)] +
            [["stop"]],  # type: ignore
            pos=getpos(code),
        )

    # Regular function without default parameters: a single selector check
    # guarding the body.
    if sig.total_default_args <= 0:
        sig_compare, _ = get_sig_statements(sig, getpos(code))
        guarded_body = (["seq"] + nonreentrant_pre + prelude +
                        [parse_body(stmt, context) for stmt in code.body] +
                        nonreentrant_post + [["stop"]])
        return LLLnode.from_list(
            ["if", sig_compare, guarded_body],
            typ=None,
            pos=getpos(code),
        )

    # Regular function with default parameters.
    function_routine = f"{sig.name}_{sig.method_id}"
    default_sigs = sig_utils.generate_default_arg_sigs(
        code, context.sigs, context.global_ctx)
    base_arg_names = {arg.name for arg in sig.base_args}
    dispatch_chain: List[Any] = ["seq"]

    for default_sig in default_sigs:
        sig_compare, _ = get_sig_statements(default_sig, getpos(code))

        # Defaults that this signature leaves unset are written to memory.
        default_setters = []
        for arg_name in get_default_names_to_set(sig, default_sig):
            default_value = Expr(sig.default_values[arg_name],
                                 context).lll_node
            record = context.vars[arg_name]
            target = LLLnode.from_list(
                record.pos,
                typ=record.typ,
                location="memory",
                pos=getpos(code),
                mutable=record.mutable,
            )
            default_setters.append(
                make_setter(target, default_value, "memory",
                            pos=getpos(code)))

        # Non-base arguments supplied by this signature, kept in signature
        # order -- the order is very important.
        supplied_names = [
            sig_arg.name for sig_arg in default_sig.args
            if sig_arg.name not in base_arg_names
        ]

        # Variables to be populated from calldata/stack.
        calldata_copiers: List[Any] = []
        if len(default_sig.args) - len(sig.base_args) > 0:
            # Map every argument of this signature to its calldata offset;
            # a byte array occupies a single 32-byte head slot.
            calldata_offsets = {}
            cursor = 4
            for sig_arg in default_sig.args:
                calldata_offsets[sig_arg.name] = cursor
                if isinstance(sig_arg.typ, ByteArrayLike):
                    cursor += 32
                else:
                    cursor += get_size_of_type(sig_arg.typ) * 32

            # Clamp, then copy, each supplied default parameter.
            for arg_name in supplied_names:
                record = context.vars[arg_name]
                arg_offset = calldata_offsets[arg_name]

                calldata_copiers.append(
                    make_arg_clamper(arg_offset - 4, record.pos, record.typ))

                # A byte array is copied from the position stored in its
                # head slot (plus the 4-byte method id).
                copy_from: Union[int, List[Any]] = (
                    ["add", 4, ["calldataload", arg_offset]]
                    if isinstance(record.typ, ByteArrayLike) else arg_offset)
                calldata_copiers.append(
                    get_external_arg_copier(
                        memory_dest=record.pos,
                        total_size=record.size * 32,
                        offset=copy_from,
                    ))

            calldata_copiers.append(0)  # for over arching seq, POP

        setter_seq = (["seq"] + default_setters
                      if default_setters else ["pass"])
        copier_seq = (["seq_unchecked"] + calldata_copiers
                      if calldata_copiers else ["pass"])
        dispatch_chain.append([
            "if",
            sig_compare,
            ["seq", setter_seq, copier_seq, ["goto", function_routine]],
        ])

    # The shared routine is skipped unless one of the checks above jumped
    # into it.
    function_jump_label = f"{sig.name}_{sig.method_id}_skip"
    routine_body = (["seq"] + nonreentrant_pre + prelude +
                    [parse_body(stmt, context) for stmt in code.body] +
                    nonreentrant_post + [["stop"]])
    return LLLnode.from_list(
        [
            "seq",
            dispatch_chain,
            [
                "seq",
                ["goto", function_jump_label],
                ["label", function_routine],
                routine_body,
                ["label", function_jump_label],
            ],
        ],
        typ=None,
        pos=getpos(code),
    )
Exemple #29
0
def parse_func(code, _globals, sigs, origcode, _custom_units, _vars=None):
    """
    Parse a function definition and return its LLL node.

    The returned node additionally carries ``context``, ``total_gas`` and
    ``func_name`` attributes.

    :param code: AST node of the function definition
    :param _globals: global variables; argument names must not appear here
    :param sigs: known signatures, passed to the signature and the context
    :param origcode: original source code, stored on the context
    :param _custom_units: custom units, passed to the signature and context
    :param _vars: initial variable mapping for the context (a fresh dict
                  when omitted)
    :return: LLLnode for the function
    :raises FunctionDeclarationException: if an argument name is also a
        global, if the default function takes arguments or is private, or
        if a return type is declared but no return statement was counted
    """
    if _vars is None:
        _vars = {}
    sig = FunctionSignature.from_definition(code, sigs=sigs, custom_units=_custom_units)
    is_init = sig.name == '__init__'

    # Argument names may not collide with globals.
    for fn_arg in sig.args:
        if fn_arg.name in _globals:
            raise FunctionDeclarationException("Variable name duplicated between function arguments and globals: " + fn_arg.name)

    context = Context(
        vars=_vars,
        globals=_globals,
        sigs=sigs,
        return_type=sig.output_type,
        is_constant=sig.const,
        is_payable=sig.payable,
        origcode=origcode,
        custom_units=_custom_units,
    )

    # Reserve room for the fixed-size part of the arguments: one 32-byte
    # slot per byte array, the full size for everything else.
    copy_size = sum(
        32 if isinstance(fn_arg.typ, ByteArrayType) else get_size_of_type(fn_arg.typ) * 32
        for fn_arg in sig.args
    )
    context.next_mem += copy_size

    # The constructor copies its arguments from code, every other function
    # from calldata (skipping the first 4 bytes); nothing to copy without
    # arguments.
    if len(sig.args) == 0:
        copier = 'pass'
    else:
        copy_op, copy_src = ('codecopy', '~codelen') if is_init else ('calldatacopy', 4)
        copier = [copy_op, MemoryPositions.RESERVED_MEMORY, copy_src, copy_size]
    prelude = [copier]

    # Asserts for non-payable and private functions.
    if not sig.payable:
        prelude.append(['assert', ['iszero', 'callvalue']])
    if sig.private:
        prelude.append(['assert', ['eq', 'caller', 'address']])

    # Clamp every argument and record where it lives: byte arrays get their
    # own region at next_mem, everything else stays in the reserved region.
    for fn_arg in sig.args:
        prelude.append(make_clamper(fn_arg.pos, context.next_mem, fn_arg.typ, is_init))
        is_bytes = isinstance(fn_arg.typ, ByteArrayType)
        var_pos = context.next_mem if is_bytes else MemoryPositions.RESERVED_MEMORY + fn_arg.pos
        context.vars[fn_arg.name] = VariableRecord(fn_arg.name, var_pos, fn_arg.typ, False)
        if is_bytes:
            context.next_mem += 32 * get_size_of_type(fn_arg.typ)

    # Assemble clampers + body in the shape required by the function kind.
    if is_init:
        o = LLLnode.from_list(['seq'] + prelude + [parse_body(code.body, context)], pos=getpos(code))
    elif is_default_func(sig):
        if len(sig.args) > 0:
            raise FunctionDeclarationException('Default function may not receive any arguments.', code)
        if sig.private:
            raise FunctionDeclarationException('Default function may only be public.', code)
        o = LLLnode.from_list(['seq'] + prelude + [parse_body(code.body, context)], pos=getpos(code))
    else:
        # Regular function: only runs when the method id matches.
        method_id_node = LLLnode.from_list(sig.method_id, pos=getpos(code), annotation='%s' % sig.name)
        guarded_body = ['seq'] + prelude + [parse_body(stmt, context) for stmt in code.body] + ['stop']
        o = LLLnode.from_list(
            ['if', ['eq', ['mload', 0], method_id_node], guarded_body],
            typ=None,
            pos=getpos(code),
        )

    # Check for at least one return statement if necessary.
    if context.return_type and context.function_return_count == 0:
        raise FunctionDeclarationException(
            "Missing return statement in function '%s' " % sig.name, code
        )

    o.context = context
    o.total_gas = o.gas + calc_mem_gas(o.context.next_mem)
    o.func_name = sig.name
    return o