Пример #1
0
def _get_external_signatures(global_ctx, sig_formatter=lambda x: x):
    """Return the formatted signatures of all non-internal functions.

    Builds a ``FunctionSignature`` for every entry in
    ``global_ctx._function_defs`` (in order), drops the internal ones, and
    passes each remaining signature through ``sig_formatter``.

    :param global_ctx: module context exposing ``_function_defs``
    :param sig_formatter: callable applied to each kept signature
        (identity by default)
    :return: list of formatted signatures, in definition order
    """
    # Lazy generator: signature construction and formatting interleave
    # per function, exactly as in a plain for-loop.
    all_sigs = (
        FunctionSignature.from_definition(fn_node, global_ctx)
        for fn_node in global_ctx._function_defs
    )
    return [sig_formatter(candidate) for candidate in all_sigs if not candidate.internal]
Пример #2
0
def generate_ir_for_module(
        global_ctx: GlobalContext
) -> Tuple[IRnode, IRnode, FunctionSignatures]:
    """Generate the deploy-time and runtime IR for a whole module.

    Steps visible here:
      1. topologically order the module's function definitions;
      2. collect signatures of external interfaces and of this module's own
         functions (the latter stored under the ``"self"`` key);
      3. build the runtime IR with ``_runtime_ir``;
      4. wrap it in a deploy sequence, running the init function (the one
         selected by ``_is_init_func``) first when there is one.

    :param global_ctx: module-level context (function defs, interfaces,
        immutables section size)
    :return: ``(deploy_ir, runtime_ir, sigs)`` where ``sigs`` maps each of
        this module's function names to its ``FunctionSignature``
    :raises CompilerPanic: if ``immutable_section_bytes`` is non-zero but the
        module has no init function
    """
    # order functions so that each function comes after all of its callees
    function_defs = _topsort(global_ctx._function_defs)

    # FunctionSignatures for all interfaces defined in this module
    all_sigs: Dict[str, FunctionSignatures] = {}
    if global_ctx._contracts or global_ctx._interfaces:
        all_sigs = parse_external_interfaces(all_sigs, global_ctx)

    # NOTE: this initial None is never read; init_function is reassigned
    # below once the function list has been partitioned.
    init_function: Optional[vy_ast.FunctionDef] = None
    sigs: FunctionSignatures = {}

    # generate all signatures
    # TODO really this should live in GlobalContext
    for f in function_defs:
        sig = FunctionSignature.from_definition(f, global_ctx)
        # add it to the global namespace.
        sigs[sig.name] = sig
        # a little hacky, eventually FunctionSignature should be
        # merged with ContractFunction and we can remove this.
        # (read back below for the init function's frame_info)
        f._metadata["signature"] = sig

    # "self" must not collide with a user-defined interface name
    assert "self" not in all_sigs
    all_sigs["self"] = sigs

    # split the init function (if any) from the functions that go into the
    # runtime code
    runtime_functions = [f for f in function_defs if not _is_init_func(f)]
    init_function = next((f for f in function_defs if _is_init_func(f)), None)

    runtime, internal_functions = _runtime_ir(runtime_functions, all_sigs,
                                              global_ctx)

    deploy_code: List[Any] = ["seq"]
    immutables_len = global_ctx.immutable_section_bytes
    if init_function:
        init_func_ir = generate_ir_for_function(init_function, all_sigs,
                                                global_ctx, False)
        deploy_code.append(init_func_ir)

        # pass the amount of memory allocated for the init function
        # so that deployment does not clobber while preparing immutables
        # note: (deploy mem_ofst, code, extra_padding)
        init_mem_used = init_function._metadata[
            "signature"].frame_info.mem_used
        deploy_code.append(["deploy", init_mem_used, runtime, immutables_len])

        # internal functions come after everything else
        # (only those listed in the init function's `called_functions` are
        # included in the deploy code)
        for f in init_function._metadata["type"].called_functions:
            deploy_code.append(internal_functions[f.name])

    else:
        # without an init function there is nowhere to set immutables, so a
        # non-empty immutables section indicates an internal inconsistency
        if immutables_len != 0:
            raise CompilerPanic("unreachable")
        deploy_code.append(["deploy", 0, runtime, 0])

    return IRnode.from_list(deploy_code), IRnode.from_list(runtime), sigs
Пример #3
0
def _get_external_signatures(global_ctx, sig_formatter=lambda x: x):
    """Return formatted signatures for every non-internal function in ``_defs``.

    Signatures are built against the module's interfaces
    (``global_ctx._contracts``) and structs (``global_ctx._structs``).

    :param global_ctx: module context exposing ``_defs``, ``_contracts`` and
        ``_structs``
    :param sig_formatter: callable applied to each kept signature
        (identity by default)
    :return: list of formatted signatures, in definition order
    """
    def _signature_of(node):
        # one signature per function AST node
        return FunctionSignature.from_definition(
            node,
            sigs=global_ctx._contracts,
            custom_structs=global_ctx._structs,
        )

    formatted = []
    for node in global_ctx._defs:
        signature = _signature_of(node)
        if signature.internal:
            continue
        formatted.append(sig_formatter(signature))
    return formatted
Пример #4
0
def generate_default_arg_sigs(code, interfaces, global_ctx):
    """Generate one ``FunctionSignature`` per default-argument call variant.

    A function with ``n`` default arguments can be called with the first
    ``k`` of them supplied, for every ``0 <= k <= n``. For each ``k`` a deep
    copy of the function AST is made whose argument list is the required
    (base) arguments followed by the first ``k`` default arguments, and a
    signature is derived from that copy.

    :param code: function definition AST node (reads ``code.args.args`` and
        ``code.args.defaults``)
    :param interfaces: forwarded to ``FunctionSignature.from_definition`` as
        ``sigs``
    :param global_ctx: module context; only ``_structs`` is used
    :return: list of signatures ordered from fewest to most arguments; a
        single signature when the function has no default arguments
    """
    num_defaults = len(code.args.defaults)
    if num_defaults == 0:
        return [
            FunctionSignature.from_definition(
                code,
                sigs=interfaces,
                custom_structs=global_ctx._structs,
            )
        ]
    base_args = code.args.args[:-num_defaults]
    default_args = code.args.args[-num_defaults:]

    sig_fun_defs = []
    for num_supplied in range(num_defaults + 1):
        new_code = copy.deepcopy(code)
        # required args (copied) followed by the first `num_supplied`
        # default args
        new_code.args.args = copy.deepcopy(base_args) + default_args[:num_supplied]
        # Every argument of this variant is mandatory, so the copy carries no
        # defaults. (The attribute is `defaults`, as read above; the previous
        # code assigned a misspelled `default` attribute and left the full
        # defaults list on the copy.)
        new_code.args.defaults = []
        sig_fun_defs.append(
            FunctionSignature.from_definition(
                new_code,
                sigs=interfaces,
                custom_structs=global_ctx._structs,
            )
        )

    return sig_fun_defs
Пример #5
0
def mk_full_signature_from_json(abi):
    """Convert a JSON ABI into a list of ``FunctionSignature`` objects.

    Only function entries are converted. An entry whose ``type`` is missing
    is treated as a function, as older versions of the ABI JSON
    specification allowed. Each entry becomes a synthetic ``@external``
    ``vy_ast.FunctionDef``, plus ``@view`` / ``@payable`` as the entry
    indicates. That definition is passed to
    ``FunctionSignature.from_definition`` with ``is_from_json=True``.

    Mutability comes from either the legacy ``constant``/``payable`` flags
    or the ``stateMutability`` field. ``"pure"`` functions are marked
    ``@view`` too. Newer ABIs omit ``constant``, so without this a pure
    function would be treated as state-modifying.

    :param abi: iterable of ABI entry dicts
    :return: list of ``FunctionSignature``, in ABI order
    """
    # NOTE(review): the second argument of abi_type_to_ast looks like a size
    # bound for dynamic types (2**20 for inputs, 1 for outputs) -- confirm.
    functions = [
        entry for entry in abi if entry.get("type", "function") == "function"
    ]
    sigs = []

    for func in functions:
        args = [
            vy_ast.arg(
                arg=a["name"],
                annotation=abi_type_to_ast(a["type"], 1048576),
                lineno=0,
                col_offset=0,
            )
            for a in func.get("inputs", [])
        ]

        outputs = func.get("outputs", [])
        returns = None
        if len(outputs) == 1:
            returns = abi_type_to_ast(outputs[0]["type"], 1)
        elif len(outputs) > 1:
            returns = vy_ast.Tuple(elements=[
                abi_type_to_ast(out["type"], 1) for out in outputs
            ])

        # Handle either legacy constant/payable flags or stateMutability
        state_mutability = func.get("stateMutability")
        is_view = bool(func.get("constant")) or state_mutability in ("view", "pure")
        is_payable = bool(func.get("payable")) or state_mutability == "payable"

        decorator_list = [vy_ast.Name(id="external")]
        if is_view:
            decorator_list.append(vy_ast.Name(id="view"))
        if is_payable:
            decorator_list.append(vy_ast.Name(id="payable"))

        sig = FunctionSignature.from_definition(
            vy_ast.FunctionDef(
                name=func["name"],
                args=vy_ast.arguments(args=args),
                decorator_list=decorator_list,
                returns=returns,
            ),
            GlobalContext(),  # dummy
            is_from_json=True,
        )
        sigs.append(sig)
    return sigs
Пример #6
0
def mk_full_signature(global_ctx, sig_formatter):
    """Build formatted signatures for every external function of the module.

    A non-internal function contributes one entry for each default-argument
    variant returned by ``generate_default_arg_sigs``. Internal functions are
    skipped.

    :param global_ctx: module context providing ``_defs``, ``_contracts`` and
        ``_structs``
    :param sig_formatter: callable applied to each generated signature
    :return: list of formatted signatures, in definition order
    """
    formatted = []

    for func_def in global_ctx._defs:
        primary = FunctionSignature.from_definition(
            func_def,
            sigs=global_ctx._contracts,
            custom_structs=global_ctx._structs,
        )
        if primary.internal:
            continue
        variants = generate_default_arg_sigs(
            func_def, global_ctx._contracts, global_ctx
        )
        formatted.extend(sig_formatter(variant) for variant in variants)

    return formatted
Пример #7
0
def parse_external_interfaces(external_interfaces, global_ctx):
    """Add signatures of every interface known to ``global_ctx``.

    Two sources are read:
      * ``global_ctx._contracts``: interface name -> list of function AST
        nodes. Each body must be a single bare name (``pure``, ``view``,
        ``nonpayable`` or ``payable``) giving its state mutability.
      * ``global_ctx._interfaces``: interface name -> iterable whose
        ``FunctionSignature`` members are kept.

    :param external_interfaces: dict that is updated in place (one entry per
        interface, mapping function name -> ``FunctionSignature``)
    :param global_ctx: module context
    :return: the same ``external_interfaces`` dict
    :raises FunctionDeclarationException: if an interface declares the same
        function name twice
    :raises StructureException: if a function's state mutability is missing
    """
    for interface_name, interface_defs in global_ctx._contracts.items():
        # TODO factor me into helper function
        def_names = [func_def.name for func_def in interface_defs]
        if len(set(def_names)) < len(interface_defs):
            # report the first name (in declaration order) that repeats
            duplicate = next(n for n in def_names if def_names.count(n) > 1)
            raise FunctionDeclarationException(
                f"Duplicate function name: {duplicate}"
            )

        interface = {}
        for func_def in interface_defs:
            body = func_def.body
            # test for valid call type keyword.
            has_mutability = (
                len(body) == 1
                and isinstance(body[0], vy_ast.Expr)
                and isinstance(body[0].value, vy_ast.Name)
                # NOTE: Can't import enums here because of circular import
                and body[0].value.id in ("pure", "view", "nonpayable", "payable")
            )
            if not has_mutability:
                raise StructureException(
                    "state mutability of call type must be specified", func_def)
            constant = body[0].value.id in ("view", "pure")

            # Recognizes already-defined structs
            sig = FunctionSignature.from_definition(
                func_def,
                sigs=global_ctx.interface_names,
                interface_def=True,
                constant_override=constant,
                custom_structs=global_ctx._structs,
            )
            interface[sig.name] = sig
        external_interfaces[interface_name] = interface

    for interface_name, interface in global_ctx._interfaces.items():
        external_interfaces[interface_name] = {
            sig.name: sig
            for sig in interface if isinstance(sig, FunctionSignature)
        }

    return external_interfaces
Пример #8
0
def parse_regular_functions(regular_functions, sigs, external_interfaces,
                            global_ctx, default_function, init_function):
    """Generate the runtime IR sequence for a module's functions.

    External functions are grouped into payable and nonpayable lists so that
    one shared ``callvalue == 0`` assertion can guard all nonpayable ones.
    When that is not possible (payable fallback plus nonpayable externals),
    each function performs its own check. The runtime is laid out as:
    calldatasize check, method-id dispatch over external functions, the
    ``fallback`` label, then internal functions.

    :param regular_functions: function AST nodes to compile (the fallback
        function is passed separately as ``default_function``)
    :param sigs: this module's function name -> ``FunctionSignature``;
        entries are replaced as each function is compiled, with updated
        ``gas``, ``frame_start`` and ``frame_size``
    :param external_interfaces: interface name -> signatures; merged with
        ``{"self": sigs}`` when compiling each function
    :param global_ctx: module context
    :param default_function: fallback function AST node, or ``None``
    :param init_function: NOTE(review): not referenced in this body
    :return: runtime IR as a nested list starting with ``"seq"``
    """
    # check for payable/nonpayable external functions to optimize nonpayable assertions
    func_types = [i._metadata["type"] for i in global_ctx._defs]
    mutabilities = [
        i.mutability for i in func_types
        if i.visibility == FunctionVisibility.EXTERNAL
    ]
    has_payable = any(i == StateMutability.PAYABLE for i in mutabilities)
    has_nonpayable = any(i != StateMutability.PAYABLE for i in mutabilities)

    is_default_payable = (default_function is not None
                          and default_function._metadata["type"].mutability
                          == StateMutability.PAYABLE)

    # TODO streamline the nonpayable check logic

    # when a contract has a payable default function and at least one nonpayable
    # external function, we must perform the nonpayable check on every function
    check_per_function = is_default_payable and has_nonpayable

    # generate IR for regular functions
    payable_funcs = []
    nonpayable_funcs = []
    internal_funcs = []
    # NOTE: add_gas accumulates across the loop, so every external function
    # gets 30 more than the previous one; internal functions receive the
    # current running total without incrementing it.
    add_gas = 0

    for func_node in regular_functions:
        func_type = func_node._metadata["type"]
        func_ir, frame_start, frame_size = generate_ir_for_function(
            func_node, {
                **{
                    "self": sigs
                },
                **external_interfaces
            }, global_ctx, check_per_function)

        if func_type.visibility == FunctionVisibility.INTERNAL:
            internal_funcs.append(func_ir)

        elif func_type.mutability == StateMutability.PAYABLE:
            add_gas += 30  # CMC 20210910 why?
            payable_funcs.append(func_ir)

        else:
            add_gas += 30  # CMC 20210910 why?
            nonpayable_funcs.append(func_ir)

        func_ir.total_gas += add_gas

        # update sigs with metadata gathered from compiling the function so that
        # we can handle calls to self
        # TODO we only need to do this for internal functions; external functions
        # cannot be called via `self`
        sig = FunctionSignature.from_definition(func_node, external_interfaces,
                                                global_ctx._structs)
        sig.gas = func_ir.total_gas
        sig.frame_start = frame_start
        sig.frame_size = frame_size
        sigs[sig.name] = sig

    # generate IR for fallback function
    if default_function:
        fallback_ir, _frame_start, _frame_size = generate_ir_for_function(
            default_function,
            {
                **{
                    "self": sigs
                },
                **external_interfaces
            },
            global_ctx,
            # include a nonpayable check here if the contract only has a default function
            check_per_function or not regular_functions,
        )
    else:
        # no fallback: any unmatched call reverts
        fallback_ir = IRnode.from_list(["revert", 0, 0],
                                       typ=None,
                                       annotation="Default function")

    if check_per_function:
        # each function carries its own nonpayable check
        external_seq = ["seq"] + payable_funcs + nonpayable_funcs

    else:
        # payable functions are placed prior to nonpayable functions
        # and separated by a nonpayable assertion
        external_seq = ["seq"]
        if has_payable:
            external_seq += payable_funcs
        if has_nonpayable:
            external_seq.append(["assert", ["iszero", "callvalue"]])
            external_seq += nonpayable_funcs

    # ensure the external jumptable section gets closed out
    # (for basic block hygiene and also for zksync interpreter)
    # NOTE: this jump gets optimized out in assembly since the
    # fallback label is the immediate next instruction.
    close_selector_section = ["goto", "fallback"]

    # bytecode is organized by: external functions, fallback fn, internal functions
    # this way we save gas and reduce bytecode by not jumping over internal functions
    runtime = [
        "seq",
        # check that calldatasize is at least 4, otherwise
        # calldataload will load zeros (cf. yellow paper).
        ["if", ["lt", "calldatasize", 4], ["goto", "fallback"]],
        [
            # method id = top 4 bytes of the first calldata word
            "with", "_calldata_method_id",
            shr(224, ["calldataload", 0]), external_seq
        ],
        close_selector_section,
        ["label", "fallback", ["var_list"], fallback_ir],
    ]
    runtime.extend(internal_funcs)

    return runtime
Пример #9
0
def make_call(stmt_expr, context):
    """Generate LLL for a call to an internal function via ``self.<name>(...)``.

    The generated code:
      1. saves the caller's memory-resident local variables on the stack;
      2. pushes the arguments;
      3. pushes a return address;
      4. jumps to the callee's ``priv_<method_id>`` label;
      5. pops return values into the placeholder obtained from
         ``_call_make_placeholder``;
      6. restores the saved locals.

    :param stmt_expr: call expression AST node; ``stmt_expr.func.attr`` is
        the method name and ``stmt_expr.args`` the argument expressions
    :param context: codegen context (signatures, variables, memory
        allocator, constancy)
    :return: LLLnode typed with ``sig.output_type`` and located in memory;
        wrapped in ``pop`` when the callee has no output type
    :raises StateAccessViolation: if a non-view/non-pure function is called
        from a constant context
    :raises StructureException: if the target is not internal, or if no
        arguments are given to a function that expects some
    """
    # ** Internal Call **
    # Steps:
    # (x) push current local variables
    # (x) push arguments
    # (x) push jumpdest (callback ptr)
    # (x) jump to label
    # (x) pop return values
    # (x) pop local variables

    pop_local_vars = []
    push_local_vars = []
    pop_return_values = []
    push_args = []
    method_name = stmt_expr.func.attr

    # TODO check this out
    # (imported locally, presumably to avoid a circular import)
    from vyper.old_codegen.expr import parse_sequence

    pre_init, expr_args = parse_sequence(stmt_expr, stmt_expr.args, context)
    sig = FunctionSignature.lookup_sig(
        context.sigs,
        method_name,
        expr_args,
        stmt_expr,
        context,
    )

    if context.is_constant() and sig.mutability not in ("view", "pure"):
        raise StateAccessViolation(
            f"May not call state modifying function "
            f"'{method_name}' within {context.pp_constancy()}.",
            getpos(stmt_expr),
        )

    if not sig.internal:
        raise StructureException("Cannot call external functions via 'self'",
                                 stmt_expr)

    # Push local variables.
    # (pos, size-in-words) of every memory variable in scope, sorted by position
    var_slots = [(v.pos, v.size) for name, v in context.vars.items()
                 if v.location == "memory"]
    if var_slots:
        var_slots.sort(key=lambda x: x[0])

        if len(var_slots) > 10:
            # if memory is large enough, push and pop it via iteration
            # [mem_from, mem_to) spans from the first variable to the end of the last
            mem_from, mem_to = var_slots[0][
                0], var_slots[-1][0] + var_slots[-1][1] * 32
            i_placeholder = context.new_internal_variable(BaseType("uint256"))
            # labels are made unique per call site via line/column
            local_save_ident = f"_{stmt_expr.lineno}_{stmt_expr.col_offset}"
            push_loop_label = "save_locals_start" + local_save_ident
            pop_loop_label = "restore_locals_start" + local_save_ident
            push_local_vars = [
                ["mstore", i_placeholder, mem_from],
                ["label", push_loop_label],
                ["mload", ["mload", i_placeholder]],
                [
                    "mstore", i_placeholder,
                    ["add", ["mload", i_placeholder], 32]
                ],
                [
                    "if", ["lt", ["mload", i_placeholder], mem_to],
                    ["goto", push_loop_label]
                ],
            ]
            # restore in reverse order (last word pushed is popped first)
            # NOTE(review): if mem_from were 0, `sub 32` from 0 would wrap and
            # `ge ... 0` would stay true; presumably locals never start at 0.
            pop_local_vars = [
                ["mstore", i_placeholder, mem_to - 32],
                ["label", pop_loop_label],
                ["mstore", ["mload", i_placeholder], "pass"],
                [
                    "mstore", i_placeholder,
                    ["sub", ["mload", i_placeholder], 32]
                ],
                [
                    "if", ["ge", ["mload", i_placeholder], mem_from],
                    ["goto", pop_loop_label]
                ],
            ]
        else:
            # for smaller memory, hardcode the mload/mstore locations
            push_mem_slots = []
            for pos, size in var_slots:
                push_mem_slots.extend([pos + i * 32 for i in range(size)])

            push_local_vars = [["mload", pos] for pos in push_mem_slots]
            pop_local_vars = [["mstore", pos, "pass"]
                              for pos in push_mem_slots[::-1]]

    # Push Arguments
    if expr_args:
        inargs, inargsize, arg_pos = pack_arguments(sig,
                                                    expr_args,
                                                    context,
                                                    stmt_expr,
                                                    is_external_call=False)
        push_args += [
            inargs
        ]  # copy arguments first, to not mess up the push/pop sequencing.

        static_arg_size = 32 * sum(
            [get_static_size_of_type(arg.typ) for arg in expr_args])
        static_pos = int(arg_pos + static_arg_size)
        needs_dyn_section = any(
            [has_dynamic_data(arg.typ) for arg in expr_args])

        if needs_dyn_section:
            ident = f"push_args_{sig.method_id}_{stmt_expr.lineno}_{stmt_expr.col_offset}"
            start_label = ident + "_start"
            end_label = ident + "_end"
            i_placeholder = context.new_internal_variable(BaseType("uint256"))

            # Calculate copy start position.
            # Given | static | dynamic | section in memory,
            # copy backwards so the values are in order on the stack.
            # We calculate i, the end of the whole encoded part
            # (i.e. the starting index for copy)
            # by taking ceil32(len<arg>) + offset<arg> + arg_pos
            # for the last dynamic argument and arg_pos is the start
            # the whole argument section.
            # NOTE(review): last_idx is only bound when some argument is
            # ByteArrayLike; if has_dynamic_data() can be true for other
            # types, the use below raises NameError -- confirm.
            idx = 0
            for arg in expr_args:
                if isinstance(arg.typ, ByteArrayLike):
                    last_idx = idx
                idx += get_static_size_of_type(arg.typ)
            push_args += [[
                "with",
                "offset",
                ["mload", arg_pos + last_idx * 32],
                [
                    "with",
                    "len_pos",
                    ["add", arg_pos, "offset"],
                    [
                        "with",
                        "len_value",
                        ["mload", "len_pos"],
                        [
                            "mstore", i_placeholder,
                            ["add", "len_pos", ["ceil32", "len_value"]]
                        ],
                    ],
                ],
            ]]
            # loop from end of dynamic section to start of dynamic section,
            # pushing each element onto the stack.
            push_args += [
                ["label", start_label],
                [
                    "if", ["lt", ["mload", i_placeholder], static_pos],
                    ["goto", end_label]
                ],
                ["mload", ["mload", i_placeholder]],
                [
                    "mstore", i_placeholder,
                    ["sub", ["mload", i_placeholder], 32]
                ],  # decrease i
                ["goto", start_label],
                ["label", end_label],
            ]

        # push static section
        # (highest address first, so the first static word ends on top)
        push_args += [["mload", pos]
                      for pos in reversed(range(arg_pos, static_pos, 32))]
    elif sig.args:
        raise StructureException(
            f"Wrong number of args for: {sig.name} (0 args given, expected {len(sig.args)})",
            stmt_expr,
        )

    # Jump to function label.
    # NOTE(review): the `+ 6` presumably is the byte distance from `pc` to
    # the jumpdest below, so the callee returns there -- confirm against the
    # assembler's encoding of this sequence.
    jump_to_func = [
        ["add", ["pc"], 6],  # set callback pointer.
        ["goto", f"priv_{sig.method_id}"],
        ["jumpdest"],
    ]

    # Pop return values.
    returner = [0]
    if sig.output_type:
        output_placeholder, returner, output_size = _call_make_placeholder(
            stmt_expr, context, sig)
        if output_size > 0:
            dynamic_offsets = []
            if isinstance(sig.output_type, (BaseType, ListType)):
                # static output: pop each word straight into the placeholder
                pop_return_values = [[
                    "mstore", ["add", output_placeholder, pos], "pass"
                ] for pos in range(0, output_size, 32)]
            elif isinstance(sig.output_type, ByteArrayLike):
                # dynamic output: discard one stack word here; the data is
                # unpacked by the dynamic unpacker below
                dynamic_offsets = [(0, sig.output_type)]
                pop_return_values = [
                    ["pop", "pass"],
                ]
            elif isinstance(sig.output_type, TupleLike):
                static_offset = 0
                pop_return_values = []
                for name, typ in sig.output_type.tuple_items():
                    if isinstance(typ, ByteArrayLike):
                        # store the member's offset word; its data is
                        # unpacked later via dynamic_offsets
                        pop_return_values.append([
                            "mstore",
                            ["add", output_placeholder, static_offset], "pass"
                        ])
                        dynamic_offsets.append(([
                            "mload",
                            ["add", output_placeholder, static_offset]
                        ], name))
                        static_offset += 32
                    else:
                        member_output_size = get_size_of_type(typ) * 32
                        pop_return_values.extend([[
                            "mstore", ["add", output_placeholder, pos], "pass"
                        ] for pos in range(static_offset, static_offset +
                                           member_output_size, 32)])
                        static_offset += member_output_size

            # append dynamic unpacker.
            dyn_idx = 0
            for in_memory_offset, _out_type in dynamic_offsets:
                ident = f"{stmt_expr.lineno}_{stmt_expr.col_offset}_arg_{dyn_idx}"
                dyn_idx += 1
                start_label = "dyn_unpack_start_" + ident
                end_label = "dyn_unpack_end_" + ident
                i_placeholder = context.new_internal_variable(
                    typ=BaseType("uint256"))
                begin_pos = ["add", output_placeholder, in_memory_offset]
                # loop until length.
                o = LLLnode.from_list(
                    [
                        "seq_unchecked",
                        ["mstore", begin_pos, "pass"],  # get len
                        ["mstore", i_placeholder, 0],
                        ["label", start_label],
                        [  # break
                            "if",
                            [
                                "ge", ["mload", i_placeholder],
                                ["ceil32", ["mload", begin_pos]]
                            ],
                            ["goto", end_label],
                        ],
                        [  # pop into correct memory slot.
                            "mstore",
                            [
                                "add", ["add", begin_pos, 32],
                                ["mload", i_placeholder]
                            ],
                            "pass",
                        ],
                        # increment i
                        [
                            "mstore", i_placeholder,
                            ["add", 32, ["mload", i_placeholder]]
                        ],
                        ["goto", start_label],
                        ["label", end_label],
                    ],
                    typ=None,
                    annotation="dynamic unpacker",
                    pos=getpos(stmt_expr),
                )
                pop_return_values.append(o)

    call_body = list(
        itertools.chain(
            ["seq_unchecked"],
            pre_init,
            push_local_vars,
            push_args,
            jump_to_func,
            pop_return_values,
            pop_local_vars,
            [returner],
        ))
    # If we have no return, we need to pop off
    # (returner is the literal 0 in that case)
    pop_returner_call_body = ["pop", call_body
                              ] if sig.output_type is None else call_body

    # NOTE(review): sig.gas is passed as add_gas_estimate and also added to
    # o.gas afterwards -- confirm this is not double counted.
    o = LLLnode.from_list(
        pop_returner_call_body,
        typ=sig.output_type,
        location="memory",
        pos=getpos(stmt_expr),
        annotation=f"Internal Call: {method_name}",
        add_gas_estimate=sig.gas,
    )
    o.gas += sig.gas
    return o
Пример #10
0
def parse_regular_functions(
    o, regular_functions, sigs, external_interfaces, global_ctx, default_function
):
    """Generate runtime LLL for a module and append its deploy code to ``o``.

    External functions are grouped into payable and nonpayable lists so that
    one shared ``callvalue == 0`` assertion can guard all nonpayable ones,
    unless per-function checks are required. When the module has immutables,
    their values are copied after the runtime code before it is returned.

    :param o: deploy-code sequence (list); one entry is appended to it
    :param regular_functions: function AST nodes to compile (the fallback
        function is passed separately as ``default_function``)
    :param sigs: this module's function name -> ``FunctionSignature``;
        entries are replaced as each function is compiled, with updated
        ``gas``, ``frame_start`` and ``frame_size``
    :param external_interfaces: interface name -> signatures; merged with
        ``{"self": sigs}`` when compiling each function
    :param global_ctx: module context (function defs, globals, structs)
    :param default_function: fallback function AST node, or ``None``
    :return: ``(o, runtime)``; ``runtime`` is the runtime LLL as a nested
        list
    """
    # check for payable/nonpayable external functions to optimize nonpayable assertions
    func_types = [i._metadata["type"] for i in global_ctx._defs]
    mutabilities = [i.mutability for i in func_types if i.visibility == FunctionVisibility.EXTERNAL]
    has_payable = any(i == StateMutability.PAYABLE for i in mutabilities)
    has_nonpayable = any(i != StateMutability.PAYABLE for i in mutabilities)

    is_default_payable = (
        default_function is not None
        and default_function._metadata["type"].mutability == StateMutability.PAYABLE
    )

    # TODO streamline the nonpayable check logic

    # when a contract has a payable default function and at least one nonpayable
    # external function, we must perform the nonpayable check on every function
    check_per_function = is_default_payable and has_nonpayable

    # generate LLL for regular functions
    payable_funcs = []
    nonpayable_funcs = []
    internal_funcs = []
    # starts from the gas of the runtime prologue and grows by 30 per
    # external function; every compiled function gets the running total
    add_gas = func_init_lll().gas

    for func_node in regular_functions:
        func_type = func_node._metadata["type"]
        func_lll, frame_start, frame_size = generate_lll_for_function(
            func_node, {**{"self": sigs}, **external_interfaces}, global_ctx, check_per_function
        )

        if func_type.visibility == FunctionVisibility.INTERNAL:
            internal_funcs.append(func_lll)

        elif func_type.mutability == StateMutability.PAYABLE:
            add_gas += 30  # CMC 20210910 why?
            payable_funcs.append(func_lll)

        else:
            add_gas += 30  # CMC 20210910 why?
            nonpayable_funcs.append(func_lll)

        func_lll.total_gas += add_gas

        # update sigs with metadata gathered from compiling the function so that
        # we can handle calls to self
        # TODO we only need to do this for internal functions; external functions
        # cannot be called via `self`
        sig = FunctionSignature.from_definition(func_node, external_interfaces, global_ctx._structs)
        sig.gas = func_lll.total_gas
        sig.frame_start = frame_start
        sig.frame_size = frame_size
        sigs[sig.name] = sig

    # generate LLL for fallback function
    if default_function:
        fallback_lll, _frame_start, _frame_size = generate_lll_for_function(
            default_function,
            {**{"self": sigs}, **external_interfaces},
            global_ctx,
            # include a nonpayable check here if the contract only has a default function
            check_per_function or not regular_functions,
        )
    else:
        # no fallback: any unmatched call reverts
        fallback_lll = LLLnode.from_list(["revert", 0, 0], typ=None, annotation="Default function")

    if check_per_function:
        # each function carries its own nonpayable check
        external_seq = ["seq"] + payable_funcs + nonpayable_funcs

    else:
        # payable functions are placed prior to nonpayable functions
        # and separated by a nonpayable assertion
        external_seq = ["seq"]
        if has_payable:
            external_seq += payable_funcs
        if has_nonpayable:
            external_seq.append(["assert", ["iszero", "callvalue"]])
            external_seq += nonpayable_funcs

    # bytecode is organized by: external functions, fallback fn, internal functions
    # this way we save gas and reduce bytecode by not jumping over internal functions
    # NOTE(review): the method id is read from memory slot 0, presumably
    # written there by func_init_lll() -- confirm.
    runtime = [
        "seq",
        func_init_lll(),
        ["with", "_calldata_method_id", ["mload", 0], external_seq],
        ["seq", ["label", "fallback"], fallback_lll],
    ]
    runtime.extend(internal_funcs)

    immutables = [_global for _global in global_ctx._globals.values() if _global.is_immutable]

    # TODO: enable usage of the data section beyond just user defined immutables
    # https://github.com/vyperlang/vyper/pull/2466#discussion_r722816358
    if len(immutables) > 0:
        # find position of the last immutable so we do not overwrite it in memory
        # when we codecopy the runtime code to memory
        immutables = sorted(immutables, key=lambda imm: imm.pos)
        start_pos = immutables[-1].pos + immutables[-1].size * 32

        # create sequence of actions to copy immutables to the end of the runtime code in memory
        # TODO: if possible, just use identity precompile
        data_section = []
        for immutable in immutables:
            # store each immutable at the end of the runtime code
            memory_loc, offset = (
                immutable.pos,
                immutable.data_offset,
            )
            lhs = LLLnode.from_list(
                ["add", start_pos + offset, "_lllsz"], typ=immutable.typ, location="memory"
            )
            rhs = LLLnode.from_list(memory_loc, typ=immutable.typ, location="memory")
            data_section.append(make_setter(lhs, rhs, pos=None))

        # TODO: use GlobalContext.immutable_section_size
        data_section_size = sum([immutable.size * 32 for immutable in immutables])
        o.append(
            [
                "with",
                "_lllsz",  # keep size of runtime bytecode in sz var
                ["lll", start_pos, runtime],  # store runtime code at `start_pos`
                # sequence of copying immutables, with final action of returning the runtime code
                ["seq", *data_section, ["return", start_pos, ["add", data_section_size, "_lllsz"]]],
            ]
        )

    else:
        # NOTE: lll macro first argument is the location in memory to store
        # the compiled bytecode
        # https://lll-docs.readthedocs.io/en/latest/lll_reference.html#code-lll
        o.append(["return", 0, ["lll", 0, runtime]])

    return o, runtime
Пример #11
0
def parse_external_function(
    code: vy_ast.FunctionDef,
    sig: FunctionSignature,
    context: Context,
    check_nonpayable: bool,
) -> LLLnode:
    """
    Parse an external function (FunctionDef) and produce its full body.

    :param code: ast of the function
    :param sig: the FunctionSignature of the function
    :param context: codegen context; argument variables are registered in
                    ``context.vars`` and memory is reserved through
                    ``context.memory_allocator`` as a side effect
    :param check_nonpayable: if True, include a check that `msg.value == 0`
                             at the beginning of the function (skipped for
                             payable functions)
    :return: LLL for the signature comparison(s) plus the function body.
             ``__init__`` and the default function are emitted without a
             signature comparison.
    """

    func_type = code._metadata["type"]

    # Get nonreentrant lock. The pre/post sequences are wrapped around the
    # body of every externally callable entry point (normal functions and
    # the default function); ``__init__`` does not get them.
    nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_type)

    clampers = []

    # Generate copiers. Only the constructor needs one: it copies
    # `sig.base_copy_size` bytes starting at `~codelen` (presumably the
    # constructor args appended to the deploy code) into reserved memory.
    # Other functions read their base args directly from calldata (below).
    copier: List[Any] = ["pass"]
    if sig.base_args and sig.name == "__init__":
        copier = [
            "codecopy", MemoryPositions.RESERVED_MEMORY, "~codelen",
            sig.base_copy_size
        ]
        context.memory_allocator.expand_memory(sig.max_copy_size)
    clampers.append(copier)

    if check_nonpayable and sig.mutability != "payable":
        # if the contract contains payable functions, but this is not one of them
        # add an assertion that the value of the call is zero
        clampers.append(["assert", ["iszero", "callvalue"]])

    # Fill variable positions
    default_args_start_pos = len(sig.base_args)
    for i, arg in enumerate(sig.args):
        if i < default_args_start_pos:
            # base (non-default) args get an input clamper up front
            clampers.append(
                make_arg_clamper(
                    arg.pos,
                    context.memory_allocator.get_next_memory_position(),
                    arg.typ,
                    sig.name == "__init__",
                ))
        if isinstance(arg.typ, ByteArrayLike):
            mem_pos = context.memory_allocator.expand_memory(
                32 * get_size_of_type(arg.typ))
            context.vars[arg.name] = VariableRecord(arg.name, mem_pos, arg.typ,
                                                    False)
        elif sig.name == "__init__":
            # constructor args live in the reserved memory they were copied to
            context.vars[arg.name] = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )
        elif i >= default_args_start_pos:  # default args need to be allocated in memory.
            type_size = get_size_of_type(arg.typ) * 32
            default_arg_pos = context.memory_allocator.expand_memory(
                type_size)
            context.vars[arg.name] = VariableRecord(
                name=arg.name,
                pos=default_arg_pos,
                typ=arg.typ,
                mutable=False,
            )
        else:
            # base args of a normal function are read straight from calldata,
            # skipping the 4-byte method id
            context.vars[arg.name] = VariableRecord(name=arg.name,
                                                    pos=4 + arg.pos,
                                                    typ=arg.typ,
                                                    mutable=False,
                                                    location="calldata")

    # Create "clampers" (input well-formedness checkers)
    # Return function body
    if sig.name == "__init__":
        o = LLLnode.from_list(
            ["seq"] + clampers +
            [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )
    # Is default function.
    elif sig.is_default_func():
        # The default function is externally callable too, so it must honor
        # its nonreentrant lock like any other external function.
        o = LLLnode.from_list(
            ["seq"] + nonreentrant_pre + clampers + [parse_body(code.body, context)] +
            nonreentrant_post + [["stop"]],  # type: ignore
            pos=getpos(code),
        )
    # Is a normal function.
    else:
        # Function with default parameters.
        if sig.total_default_args > 0:
            function_routine = f"{sig.name}_{sig.method_id}"
            default_sigs = sig_utils.generate_default_arg_sigs(
                code, context.sigs, context.global_ctx)
            sig_chain: List[Any] = ["seq"]
            # names of the args that every variant of the signature provides
            base_arg_names = {arg.name for arg in sig.base_args}

            for default_sig in default_sigs:
                sig_compare, _ = get_sig_statements(default_sig, getpos(code))

                # Populate unset default variables
                set_defaults = []
                for arg_name in get_default_names_to_set(sig, default_sig):
                    value = Expr(sig.default_values[arg_name],
                                 context).lll_node
                    var = context.vars[arg_name]
                    left = LLLnode.from_list(
                        var.pos,
                        typ=var.typ,
                        location="memory",
                        pos=getpos(code),
                        mutable=var.mutable,
                    )
                    set_defaults.append(
                        make_setter(left, value, "memory", pos=getpos(code)))

                copier_arg_count = len(default_sig.args) - len(sig.base_args)
                # Default args that this variant does receive via calldata,
                # kept in the order they appear in default_sig.args.
                copier_arg_names = [
                    x.name for x in default_sig.args
                    if x.name not in base_arg_names
                ]

                # Variables to be populated from calldata/stack.
                default_copiers: List[Any] = []
                if copier_arg_count > 0:
                    # Get map of variables in calldata, with their offsets
                    offset = 4
                    calldata_offset_map = {}
                    for arg in default_sig.args:
                        calldata_offset_map[arg.name] = offset
                        offset += (32 if isinstance(arg.typ, ByteArrayLike)
                                   else get_size_of_type(arg.typ) * 32)

                    # Copy default parameters from calldata.
                    for arg_name in copier_arg_names:
                        var = context.vars[arg_name]
                        calldata_offset = calldata_offset_map[arg_name]

                        # Add clampers.
                        default_copiers.append(
                            make_arg_clamper(
                                calldata_offset - 4,
                                var.pos,
                                var.typ,
                            ))
                        # Add copying code. Byte arrays are referenced by a
                        # calldata offset stored in their head slot.
                        _offset: Union[int, List[Any]] = calldata_offset
                        if isinstance(var.typ, ByteArrayLike):
                            _offset = [
                                "add", 4, ["calldataload", calldata_offset]
                            ]
                        default_copiers.append(
                            get_external_arg_copier(
                                memory_dest=var.pos,
                                total_size=var.size * 32,
                                offset=_offset,
                            ))

                    default_copiers.append(0)  # for the overarching seq_unchecked, POP

                sig_chain.append([
                    "if",
                    sig_compare,
                    [
                        "seq",
                        ["seq"] + set_defaults if set_defaults else ["pass"],
                        ["seq_unchecked"] +
                        default_copiers if default_copiers else ["pass"],
                        ["goto", function_routine],
                    ],
                ])

            # Function with default parameters.
            function_jump_label = f"{sig.name}_{sig.method_id}_skip"
            o = LLLnode.from_list(
                [
                    "seq",
                    sig_chain,
                    [
                        "seq",
                        ["goto", function_jump_label],
                        ["label", function_routine],
                        ["seq"] + nonreentrant_pre + clampers +
                        [parse_body(c, context)
                         for c in code.body] + nonreentrant_post + [["stop"]],
                        ["label", function_jump_label],
                    ],
                ],
                typ=None,
                pos=getpos(code),
            )

        else:
            # Function without default parameters.
            sig_compare, _ = get_sig_statements(sig, getpos(code))
            o = LLLnode.from_list(
                [
                    "if",
                    sig_compare,
                    ["seq"] + nonreentrant_pre + clampers +
                    [parse_body(c, context)
                     for c in code.body] + nonreentrant_post + [["stop"]],
                ],
                typ=None,
                pos=getpos(code),
            )
    return o