Ejemplo n.º 1
0
def parse_other_functions(o,
                          otherfuncs,
                          sigs,
                          external_contracts,
                          origcode,
                          global_ctx,
                          default_function):
    """Compile every non-init function into a single runtime LLL ``seq``.

    Each function in ``otherfuncs`` is parsed and appended to the runtime body.
    Its ``total_gas`` is raised by ``func_init_lll().gas`` plus 30 for each
    function parsed before it. The signatures generated for its
    default-argument variants are stored in ``sigs`` with that same gas value.
    A ``fallback`` label is appended last. It wraps either the parsed default
    function or a plain ``revert``.

    :param o: list of LLL nodes; the runtime ``return`` is appended to it.
    :param otherfuncs: function AST nodes to compile.
    :param sigs: signature map exposed as ``self``; updated in place.
    :param external_contracts: extra name -> signature-map entries made
        visible to the compiled functions.
    :param origcode: original source text, forwarded to ``parse_function``.
    :param global_ctx: global compilation context.
    :param default_function: falsy, or a sequence whose first element is the
        default function's AST node.
    :return: tuple ``(o, sub)`` where ``sub`` is the runtime ``seq`` list.
    """
    runtime = ['seq', func_init_lll()]
    gas_offset = func_init_lll().gas

    for func_def in otherfuncs:
        visible_sigs = {'self': sigs, **external_contracts}
        parsed = parse_function(func_def, visible_sigs, origcode, global_ctx)
        runtime.append(parsed)
        parsed.total_gas += gas_offset
        gas_offset += 30
        # Default-argument variants share the gas estimate of the full function.
        for variant in sig_utils.generate_default_arg_sigs(func_def, external_contracts, global_ctx):
            variant.gas = parsed.total_gas
            sigs[variant.sig] = variant

    # Add fallback function
    if not default_function:
        fallback = LLLnode.from_list(['revert', 0, 0], typ=None, annotation='Default function')
    else:
        fallback = parse_function(
            default_function[0],
            {'self': sigs, **external_contracts},
            origcode,
            global_ctx,
        )
    runtime.append(['seq_unchecked', ['label', 'fallback'], fallback])
    o.append(['return', 0, ['lll', runtime, 0]])
    return o, runtime
Ejemplo n.º 2
0
def parse_other_functions(
    o, otherfuncs, sigs, external_interfaces, origcode, global_ctx, default_function
):
    """Compile all non-init functions into the runtime LLL body.

    Each compiled function's ``total_gas`` is increased by the init-code gas
    plus 30 for every function compiled before it. The signatures of its
    default-argument variants are registered in ``sigs`` with that gas. The
    body ends with a ``fallback`` label that wraps the default function, or a
    ``revert`` when there is none.

    :param o: list of LLL nodes; the runtime ``return`` is appended to it.
    :param otherfuncs: function AST nodes to compile.
    :param sigs: signature map exposed as ``self``; updated in place.
    :param external_interfaces: extra name -> signature-map entries made
        visible to the compiled functions.
    :param origcode: original source text, forwarded to ``parse_function``.
    :param global_ctx: global compilation context.
    :param default_function: falsy, or a sequence whose first element is the
        default function's AST node.
    :return: tuple ``(o, sub)`` where ``sub`` is the runtime ``seq`` list.
    """

    def visible_sigs():
        # Signature namespaces reachable from inside the function bodies.
        return {"self": sigs, **external_interfaces}

    runtime_seq = ["seq", func_init_lll()]
    next_gas_bump = func_init_lll().gas

    for fn_ast in otherfuncs:
        fn_lll = parse_function(fn_ast, visible_sigs(), origcode, global_ctx)
        runtime_seq.append(fn_lll)
        fn_lll.total_gas += next_gas_bump
        next_gas_bump += 30
        default_variants = sig_utils.generate_default_arg_sigs(
            fn_ast, external_interfaces, global_ctx
        )
        for variant in default_variants:
            variant.gas = fn_lll.total_gas
            sigs[variant.sig] = variant

    # Add fallback function
    fallback = (
        parse_function(default_function[0], visible_sigs(), origcode, global_ctx)
        if default_function
        else LLLnode.from_list(["revert", 0, 0], typ=None, annotation="Default function")
    )
    runtime_seq.append(["seq_unchecked", ["label", "fallback"], fallback])
    o.append(["return", 0, ["lll", runtime_seq, 0]])
    return o, runtime_seq
Ejemplo n.º 3
0
def parse_other_functions(o, otherfuncs, sigs, external_contracts, origcode,
                          global_ctx, default_function, runtime_only):
    """Compile all non-init functions and the fallback into one LLL ``seq``.

    The sequence starts with ``FUNC_INIT_LLL``, followed by each compiled
    function and finally the default function (or a ``revert``). Each
    function's ``total_gas`` is increased by ``FUNC_INIT_LLL.gas`` plus 30 for
    every function compiled before it. Its default-argument variant
    signatures are registered in ``sigs`` with that gas.

    :param o: list of LLL nodes; extended only when ``runtime_only`` is falsy.
    :param otherfuncs: function AST nodes to compile.
    :param sigs: signature map exposed as ``self``; updated in place.
    :param external_contracts: extra name -> signature-map entries made
        visible to the compiled functions.
    :param origcode: original source text, forwarded to ``parse_function``.
    :param global_ctx: global compilation context.
    :param default_function: falsy, or a sequence whose first element is the
        default function's AST node.
    :param runtime_only: when truthy, return the bare ``seq`` list instead of
        wrapping it in a ``return``/``lll`` node appended to ``o``.
    :return: the ``seq`` list if ``runtime_only``, otherwise ``o``.
    """
    body = ['seq', FUNC_INIT_LLL]
    extra_gas = FUNC_INIT_LLL.gas

    for func_ast in otherfuncs:
        func_lll = parse_function(func_ast, {'self': sigs, **external_contracts},
                                  origcode, global_ctx)
        body.append(func_lll)
        func_lll.total_gas += extra_gas
        extra_gas += 30
        for default_sig in sig_utils.generate_default_arg_sigs(
                func_ast, external_contracts, global_ctx):
            default_sig.gas = func_lll.total_gas
            sigs[default_sig.sig] = default_sig

    # The fallback is appended last, directly after the regular functions.
    if not default_function:
        fallback = LLLnode.from_list(['revert', 0, 0],
                                     typ=None,
                                     annotation='Default function')
    else:
        fallback = parse_function(default_function[0],
                                  {'self': sigs, **external_contracts},
                                  origcode,
                                  global_ctx)
    body.append(fallback)

    if not runtime_only:
        o.append(['return', 0, ['lll', body, 0]])
        return o
    return body
Ejemplo n.º 4
0
def parse_private_function(code: ast.FunctionDef, sig: FunctionSignature,
                           context: Context) -> LLLnode:
    """
    Parse a private function (FuncDef), and produce full function body.

    :param code: ast of function
    :param sig: the FunctionSignature
    :param context: per-function compilation context; its memory allocator,
        ``vars`` and ``callback_ptr`` are updated in place
    :return: full sig compare & function body
    """

    validate_private_function(code, sig)

    # Get nonreentrant lock
    nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(
        sig, context.global_ctx)

    clampers: List[LLLnode] = []

    # Allocate variable space.
    context.memory_allocator.increase_memory(sig.max_copy_size)

    # Create callback_ptr, this stores a destination in the bytecode for a private
    # function to jump to after a function has executed.
    _post_callback_ptr = f"{sig.name}_{sig.method_id}_post_callback_ptr"
    context.callback_ptr = context.new_placeholder(typ=BaseType('uint256'))
    clampers.append(
        LLLnode.from_list(
            ['mstore', context.callback_ptr, 'pass'],
            annotation='pop callback pointer',
        ))
    if sig.total_default_args > 0:
        clampers.append(LLLnode.from_list(['label', _post_callback_ptr]))

    # private functions without return types need to jump back to
    # the calling function, as there is no return statement to handle the
    # jump.
    if sig.output_type is None:
        stop_func = [['jump', ['mload', context.callback_ptr]]]
    else:
        stop_func = [['stop']]

    # Generate copiers
    if len(sig.base_args) == 0:
        copier = ['pass']
        clampers.append(LLLnode.from_list(copier))
    elif sig.total_default_args == 0:
        copier = get_private_arg_copier(
            total_size=sig.base_copy_size,
            memory_dest=MemoryPositions.RESERVED_MEMORY)
        clampers.append(LLLnode.from_list(copier))

    # Fill variable positions
    for arg in sig.args:
        if isinstance(arg.typ, ByteArrayLike):
            mem_pos, _ = context.memory_allocator.increase_memory(
                32 * get_size_of_type(arg.typ))
            context.vars[arg.name] = VariableRecord(arg.name, mem_pos, arg.typ,
                                                    False)
        else:
            context.vars[arg.name] = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )

    # Private function copiers. No clamping for private functions.
    dyn_variable_names = [
        a.name for a in sig.base_args if isinstance(a.typ, ByteArrayLike)
    ]
    if dyn_variable_names:
        i_placeholder = context.new_placeholder(typ=BaseType('uint256'))
        unpackers: List[Any] = []
        for idx, var_name in enumerate(dyn_variable_names):
            var = context.vars[var_name]
            ident = f"_load_args_{sig.method_id}_dynarg{idx}"
            unpackers.append(
                make_unpacker(ident=ident,
                              i_placeholder=i_placeholder,
                              begin_pos=var.pos))

        # 0 added to complete full overarching 'seq' statement, see private_label.
        unpackers.append(0)
        clampers.append(
            LLLnode.from_list(
                ['seq_unchecked'] + unpackers,
                typ=None,
                annotation='dynamic unpacker',
                pos=getpos(code),
            ))

    # Function has default arguments.
    if sig.total_default_args > 0:  # Function with default parameters.

        default_sigs = sig_utils.generate_default_arg_sigs(
            code, context.sigs, context.global_ctx)
        sig_chain: List[Any] = ['seq']

        for default_sig in default_sigs:
            sig_compare, private_label = get_sig_statements(
                default_sig, getpos(code))

            # Populate unset default variables
            set_defaults = []
            for arg_name in get_default_names_to_set(sig, default_sig):
                value = Expr(sig.default_values[arg_name], context).lll_node
                var = context.vars[arg_name]
                left = LLLnode.from_list(var.pos,
                                         typ=var.typ,
                                         location='memory',
                                         pos=getpos(code),
                                         mutable=var.mutable)
                set_defaults.append(
                    make_setter(left, value, 'memory', pos=getpos(code)))

            # Load all variables in default section, because for private
            # functions the stack is a linear pipe. The names are taken in
            # declaration order, which matters for the copiers below.
            copier_arg_names = [x.name for x in default_sig.args]

            # Variables to be populated from the stack.
            default_copiers: List[Any] = []
            if copier_arg_names:
                # Copy set default parameters; byte arrays are copied as a
                # single 32-byte word and unpacked afterwards.
                dynamics = []
                for arg_name in copier_arg_names:
                    var = context.vars[arg_name]
                    if isinstance(var.typ, ByteArrayLike):
                        _size = 32
                        dynamics.append(var.pos)
                    else:
                        _size = var.size * 32
                    default_copiers.append(
                        get_private_arg_copier(
                            memory_dest=var.pos,
                            total_size=_size,
                        ))

                # Unpack byte array if necessary.
                if dynamics:
                    i_placeholder = context.new_placeholder(
                        typ=BaseType('uint256'))
                    for idx, var_pos in enumerate(dynamics):
                        ident = f'unpack_default_sig_dyn_{default_sig.method_id}_arg{idx}'
                        default_copiers.append(
                            make_unpacker(
                                ident=ident,
                                i_placeholder=i_placeholder,
                                begin_pos=var_pos,
                            ))
                default_copiers.append(0)  # for over arching seq, POP

            sig_chain.append([
                'if', sig_compare,
                [
                    'seq', private_label,
                    LLLnode.from_list([
                        'mstore',
                        context.callback_ptr,
                        'pass',
                    ],
                                      annotation='pop callback pointer',
                                      pos=getpos(code)),
                    (['seq'] + set_defaults) if set_defaults else ['pass'],
                    (['seq_unchecked'] + default_copiers) if default_copiers else ['pass'],
                    ['goto', _post_callback_ptr]
                ]
            ])

        # With private functions all variable loading occurs in the default
        # function sub routine.
        _clampers = [['label', _post_callback_ptr]]

        # Function with default parameters.
        o = LLLnode.from_list(
            [
                'seq',
                sig_chain,
                [
                    'if',
                    0,  # can only be jumped into
                    [
                        'seq', ['seq'] + nonreentrant_pre + _clampers +
                        [parse_body(c, context)
                         for c in code.body] + nonreentrant_post + stop_func
                    ],
                ],
            ],
            typ=None,
            pos=getpos(code))

    else:
        # Function without default parameters.
        sig_compare, private_label = get_sig_statements(sig, getpos(code))
        o = LLLnode.from_list([
            'if', sig_compare, ['seq'] + [private_label] + nonreentrant_pre +
            clampers + [parse_body(c, context)
                        for c in code.body] + nonreentrant_post + stop_func
        ],
                              typ=None,
                              pos=getpos(code))

    return o
Ejemplo n.º 5
0
def parse_other_functions(o, otherfuncs, sigs, external_interfaces, global_ctx,
                          default_function):
    """Build the runtime LLL for every function other than ``__init__``.

    The runtime body is laid out in this order:

    1. the init code;
    2. the external functions, with payable ones first;
    3. the ``fallback`` label holding the default function (or a ``revert``);
    4. the internal functions.

    :param o: list of LLL nodes; the runtime ``return`` is appended to it.
    :param otherfuncs: function AST nodes to compile.
    :param sigs: signature map exposed as ``self``; default-argument variant
        signatures are registered in it in place.
    :param external_interfaces: extra name -> signature-map entries made
        visible to the compiled functions.
    :param global_ctx: global compilation context.
    :param default_function: AST node of the default function, or None.
    :return: tuple ``(o, main_seq)``.
    """
    # Check for payable/nonpayable external functions to optimize nonpayable assertions.
    func_types = [i._metadata["type"] for i in global_ctx._defs]
    mutabilities = [
        i.mutability for i in func_types
        if i.visibility == FunctionVisibility.EXTERNAL
    ]
    has_payable = any(m == StateMutability.PAYABLE for m in mutabilities)
    has_nonpayable = any(m != StateMutability.PAYABLE for m in mutabilities)
    is_default_payable = (
        default_function is not None
        and default_function._metadata["type"].mutability == StateMutability.PAYABLE
    )
    # When a contract has a payable default function and at least one nonpayable
    # external function, we must perform the nonpayable check on every function.
    check_per_function = is_default_payable and has_nonpayable

    # Generate LLL for regular functions.
    payable_func_sub = ["seq"]
    external_func_sub = ["seq"]
    internal_func_sub = ["seq"]
    add_gas = func_init_lll().gas

    for func_node in otherfuncs:
        func_type = func_node._metadata["type"]
        func_lll = parse_function(
            func_node,
            {"self": sigs, **external_interfaces},
            global_ctx,
            check_per_function,
        )
        # Only external (payable or nonpayable) functions advance add_gas.
        if func_type.visibility == FunctionVisibility.INTERNAL:
            internal_func_sub.append(func_lll)
        elif func_type.mutability == StateMutability.PAYABLE:
            add_gas += 30
            payable_func_sub.append(func_lll)
        else:
            external_func_sub.append(func_lll)
            add_gas += 30
        func_lll.total_gas += add_gas
        for sig in sig_utils.generate_default_arg_sigs(func_node,
                                                       external_interfaces,
                                                       global_ctx):
            sig.gas = func_lll.total_gas
            sigs[sig.sig] = sig

    # Generate LLL for fallback function.
    if default_function:
        fallback_lll = parse_function(
            default_function,
            {"self": sigs, **external_interfaces},
            global_ctx,
            # include a nonpayable check here if the contract only has a default function
            check_per_function or not otherfuncs,
        )
    else:
        fallback_lll = LLLnode.from_list(["revert", 0, 0],
                                         typ=None,
                                         annotation="Default function")

    if check_per_function:
        external_seq = ["seq", payable_func_sub, external_func_sub]
    else:
        # Payable functions are placed prior to nonpayable functions
        # and separated by a nonpayable assertion.
        external_seq = ["seq"]
        if has_payable:
            external_seq.append(payable_func_sub)
        if has_nonpayable:
            external_seq.extend([["assert", ["iszero", "callvalue"]],
                                 external_func_sub])

    # Bytecode is organized by: external functions, fallback fn, internal functions.
    # This way we save gas and reduce bytecode by not jumping over internal functions.
    main_seq = [
        "seq",
        func_init_lll(),
        ["with", "_func_sig", ["mload", 0], external_seq],
        ["seq_unchecked", ["label", "fallback"], fallback_lll],
        internal_func_sub,
    ]

    o.append(["return", 0, ["lll", main_seq, 0]])
    return o, main_seq
Ejemplo n.º 6
0
def parse_func(code, sigs, origcode, global_ctx, _vars=None):
    """Parse a function definition and produce its full LLL body.

    Handles ``__init__``, the default function, and ordinary public/private
    functions. For functions with default arguments, it also builds one entry
    point per generated default-argument signature.

    :param code: function AST node.
    :param sigs: signature map used to build the signature and the context.
    :param origcode: original source text, forwarded to the context.
    :param global_ctx: global compilation context.
    :param _vars: optional initial variable map for the context; a fresh
        dict is used when omitted.
    :return: LLLnode with ``context``, ``total_gas`` and ``func_name`` set.
    :raises FunctionDeclarationException: if any of the following holds:
        ``__init__`` has default parameters; an argument name duplicates a
        global; the default function takes arguments or is private; or a
        function with a return type has no return statement.
    """
    if _vars is None:
        _vars = {}
    sig = FunctionSignature.from_definition(
        code,
        sigs=sigs,
        custom_units=global_ctx._custom_units,
        custom_structs=global_ctx._structs,
        constants=global_ctx._constants)
    # Split the arguments into base (required) and default (optional) ones.
    total_default_args = len(code.args.defaults)
    if total_default_args > 0:
        base_args = sig.args[:-total_default_args]
        default_args = code.args.args[-total_default_args:]
    else:
        # Explicit branch: slicing with ``-0`` would select *every* argument.
        base_args = sig.args
        default_args = []
    default_values = dict(
        zip([arg.arg for arg in default_args], code.args.defaults))
    # __init__ function may not have defaults.
    if sig.name == '__init__' and total_default_args > 0:
        raise FunctionDeclarationException(
            "__init__ function may not have default parameters.")
    # Check for duplicate variables with globals
    for arg in sig.args:
        if arg.name in global_ctx._globals:
            raise FunctionDeclarationException(
                "Variable name duplicated between function arguments and globals: "
                + arg.name)

    nonreentrant_pre = [['pass']]
    nonreentrant_post = [['pass']]
    if sig.nonreentrant_key:
        nkey = global_ctx.get_nonrentrant_counter(sig.nonreentrant_key)
        nonreentrant_pre = [[
            'seq', ['assert', ['iszero', ['sload', nkey]]],
            ['sstore', nkey, 1]
        ]]
        nonreentrant_post = [['sstore', nkey, 0]]

    # Create a local (per function) context.
    memory_allocator = MemoryAllocator()
    context = Context(
        vars=_vars,
        global_ctx=global_ctx,
        sigs=sigs,
        memory_allocator=memory_allocator,
        return_type=sig.output_type,
        constancy=Constancy.Constant if sig.const else Constancy.Mutable,
        is_payable=sig.payable,
        origcode=origcode,
        is_private=sig.private,
        method_id=sig.method_id)

    # Copy calldata to memory for fixed-size arguments; each byte-array
    # argument takes a single 32-byte word in this area.
    max_copy_size = sum(
        32 if isinstance(arg.typ, ByteArrayLike) else get_size_of_type(arg.typ) * 32
        for arg in sig.args
    )
    base_copy_size = sum(
        32 if isinstance(arg.typ, ByteArrayLike) else get_size_of_type(arg.typ) * 32
        for arg in base_args
    )
    context.memory_allocator.increase_memory(max_copy_size)

    clampers = []

    # Create callback_ptr, this stores a destination in the bytecode for a private
    # function to jump to after a function has executed.
    _post_callback_ptr = "{}_{}_post_callback_ptr".format(
        sig.name, sig.method_id)
    if sig.private:
        context.callback_ptr = context.new_placeholder(typ=BaseType('uint256'))
        clampers.append(
            LLLnode.from_list(
                ['mstore', context.callback_ptr, 'pass'],
                annotation='pop callback pointer',
            ))
        if total_default_args > 0:
            clampers.append(['label', _post_callback_ptr])

    # private functions without return types need to jump back to
    # the calling function, as there is no return statement to handle the
    # jump.
    stop_func = [['stop']]
    if sig.output_type is None and sig.private:
        stop_func = [['jump', ['mload', context.callback_ptr]]]

    if not base_args:
        copier = 'pass'
    elif sig.name == '__init__':
        copier = [
            'codecopy', MemoryPositions.RESERVED_MEMORY, '~codelen',
            base_copy_size
        ]
    else:
        copier = get_arg_copier(sig=sig,
                                total_size=base_copy_size,
                                memory_dest=MemoryPositions.RESERVED_MEMORY)
    clampers.append(copier)

    # Add asserts for payable and internal
    # private never gets payable check.
    if not sig.payable and not sig.private:
        clampers.append(['assert', ['iszero', 'callvalue']])

    # Fill variable positions
    for i, arg in enumerate(sig.args):
        if i < len(base_args) and not sig.private:

            clampers.append(
                make_clamper(
                    arg.pos,
                    context.memory_allocator.get_next_memory_position(),
                    arg.typ,
                    sig.name == '__init__',
                ))
        if isinstance(arg.typ, ByteArrayLike):
            mem_pos, _ = context.memory_allocator.increase_memory(
                32 * get_size_of_type(arg.typ))
            context.vars[arg.name] = VariableRecord(arg.name, mem_pos, arg.typ,
                                                    False)
        else:
            context.vars[arg.name] = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )

    # Private function copiers. No clamping for private functions.
    dyn_variable_names = [
        a.name for a in base_args if isinstance(a.typ, ByteArrayLike)
    ]
    if sig.private and dyn_variable_names:
        i_placeholder = context.new_placeholder(typ=BaseType('uint256'))
        unpackers = []
        for idx, var_name in enumerate(dyn_variable_names):
            var = context.vars[var_name]
            ident = "_load_args_%d_dynarg%d" % (sig.method_id, idx)
            unpackers.append(
                make_unpacker(ident=ident,
                              i_placeholder=i_placeholder,
                              begin_pos=var.pos))

        clampers.append(
            LLLnode.from_list(
                # [0] to complete full overarching 'seq' statement, see private_label.
                ['seq_unchecked'] + unpackers + [0],
                typ=None,
                annotation='dynamic unpacker',
                pos=getpos(code),
            ))

    # Create "clampers" (input well-formedness checkers)
    # Return function body
    if sig.name == '__init__':
        o = LLLnode.from_list(
            ['seq'] + clampers + [parse_body(code.body, context)],
            pos=getpos(code),
        )
    elif is_default_func(sig):
        if len(sig.args) > 0:
            raise FunctionDeclarationException(
                'Default function may not receive any arguments.', code)
        if sig.private:
            raise FunctionDeclarationException(
                'Default function may only be public.',
                code,
            )
        o = LLLnode.from_list(
            ['seq'] + clampers + [parse_body(code.body, context)],
            pos=getpos(code),
        )
    else:

        if total_default_args > 0:  # Function with default parameters.
            function_routine = "{}_{}".format(sig.name, sig.method_id)
            default_sigs = sig_utils.generate_default_arg_sigs(
                code, sigs, global_ctx)
            sig_chain = ['seq']
            # Names of the required arguments; identical for every variant.
            base_arg_names = {arg.name for arg in base_args}

            for default_sig in default_sigs:
                sig_compare, private_label = get_sig_statements(
                    default_sig, getpos(code))
                current_sig_arg_names = {x.name for x in default_sig.args}

                # Populate unset default variables
                populate_arg_count = len(sig.args) - len(default_sig.args)
                set_defaults = []
                if populate_arg_count > 0:
                    missing_arg_names = [
                        arg.arg for arg in default_args
                        if arg.arg not in current_sig_arg_names
                    ]
                    for arg_name in missing_arg_names:
                        value = Expr(default_values[arg_name],
                                     context).lll_node
                        var = context.vars[arg_name]
                        left = LLLnode.from_list(var.pos,
                                                 typ=var.typ,
                                                 location='memory',
                                                 pos=getpos(code),
                                                 mutable=var.mutable)
                        set_defaults.append(
                            make_setter(left,
                                        value,
                                        'memory',
                                        pos=getpos(code)))

                if sig.private:
                    # Load all variables in default section, if private,
                    # because the stack is a linear pipe.
                    copier_arg_count = len(default_sig.args)
                    copier_arg_names = current_sig_arg_names
                else:
                    copier_arg_count = len(default_sig.args) - len(base_args)
                    copier_arg_names = current_sig_arg_names - base_arg_names
                # Order copier_arg_names, this is very important.
                copier_arg_names = [
                    x.name for x in default_sig.args
                    if x.name in copier_arg_names
                ]

                # Variables to be populated from calldata/stack.
                default_copiers = []
                if copier_arg_count > 0:
                    # Get map of variables in calldata, with their offsets
                    offset = 4
                    calldata_offset_map = {}
                    for arg in default_sig.args:
                        calldata_offset_map[arg.name] = offset
                        offset += (32 if isinstance(arg.typ, ByteArrayLike)
                                   else get_size_of_type(arg.typ) * 32)
                    # Copy set default parameters from calldata
                    dynamics = []
                    for arg_name in copier_arg_names:
                        var = context.vars[arg_name]
                        calldata_offset = calldata_offset_map[arg_name]
                        if sig.private:
                            _offset = calldata_offset
                            if isinstance(var.typ, ByteArrayLike):
                                _size = 32
                                dynamics.append(var.pos)
                            else:
                                _size = var.size * 32
                            default_copiers.append(
                                get_arg_copier(
                                    sig=sig,
                                    memory_dest=var.pos,
                                    total_size=_size,
                                    offset=_offset,
                                ))
                        else:
                            # Add clampers.
                            default_copiers.append(
                                make_clamper(
                                    calldata_offset - 4,
                                    var.pos,
                                    var.typ,
                                ))
                            # Add copying code.
                            if isinstance(var.typ, ByteArrayLike):
                                _offset = [
                                    'add', 4,
                                    ['calldataload', calldata_offset]
                                ]
                            else:
                                _offset = calldata_offset
                            default_copiers.append(
                                get_arg_copier(
                                    sig=sig,
                                    memory_dest=var.pos,
                                    total_size=var.size * 32,
                                    offset=_offset,
                                ))

                    # Unpack byte array if necessary.
                    if dynamics:
                        i_placeholder = context.new_placeholder(
                            typ=BaseType('uint256'))
                        for idx, var_pos in enumerate(dynamics):
                            ident = 'unpack_default_sig_dyn_%d_arg%d' % (
                                default_sig.method_id, idx)
                            default_copiers.append(
                                make_unpacker(
                                    ident=ident,
                                    i_placeholder=i_placeholder,
                                    begin_pos=var_pos,
                                ))
                    default_copiers.append(0)  # for over arching seq, POP

                sig_chain.append([
                    'if', sig_compare,
                    [
                        'seq', private_label, ['pass'] if not sig.private else
                        LLLnode.from_list([
                            'mstore',
                            context.callback_ptr,
                            'pass',
                        ],
                                          annotation='pop callback pointer',
                                          pos=getpos(code)),
                        (['seq'] + set_defaults) if set_defaults else ['pass'],
                        (['seq_unchecked'] + default_copiers)
                        if default_copiers else ['pass'],
                        [
                            'goto', _post_callback_ptr
                            if sig.private else function_routine
                        ]
                    ]
                ])

            # With private functions all variable loading occurs in the default
            # function sub routine.
            if sig.private:
                _clampers = [['label', _post_callback_ptr]]
            else:
                _clampers = clampers

            # Function with default parameters.
            o = LLLnode.from_list(
                [
                    'seq',
                    sig_chain,
                    [
                        'if',
                        0,  # can only be jumped into
                        [
                            'seq', ['label', function_routine]
                            if not sig.private else ['pass'],
                            ['seq'] + nonreentrant_pre + _clampers +
                            [parse_body(c, context) for c in code.body] +
                            nonreentrant_post + stop_func
                        ],
                    ],
                ],
                typ=None,
                pos=getpos(code))

        else:
            # Function without default parameters.
            sig_compare, private_label = get_sig_statements(sig, getpos(code))
            o = LLLnode.from_list([
                'if', sig_compare,
                ['seq'] + [private_label] + nonreentrant_pre + clampers +
                [parse_body(c, context)
                 for c in code.body] + nonreentrant_post + stop_func
            ],
                                  typ=None,
                                  pos=getpos(code))

    # Check for at least one return statement if necessary.
    if context.return_type and context.function_return_count == 0:
        raise FunctionDeclarationException(
            "Missing return statement in function '%s' " % sig.name, code)

    o.context = context
    o.total_gas = o.gas + calc_mem_gas(
        o.context.memory_allocator.get_next_memory_position())
    o.func_name = sig.name
    return o
def parse_external_function(
    code: vy_ast.FunctionDef,
    sig: FunctionSignature,
    context: Context,
    check_nonpayable: bool,
) -> LLLnode:
    """
    Parse an external function (FunctionDef) and produce the full function body.

    :param code: ast of function
    :param sig: the FunctionSignature
    :param context: compilation context; argument variables are registered in
                    `context.vars` and memory is reserved through
                    `context.memory_allocator`
    :param check_nonpayable: if True, include a check that `msg.value == 0`
                             at the beginning of the function
    :return: full sig compare & function body
    """

    func_type = code._metadata["type"]

    # Get nonreentrant lock
    nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(
        func_type, context.global_ctx)

    # "Clampers": input well-formedness checks (plus the copier) emitted
    # before the function body.
    clampers: List[Any] = []

    # Generate copiers. Only a constructor with base arguments needs one: it
    # copies `sig.base_copy_size` bytes starting at `~codelen` into reserved
    # memory.
    copier: List[Any] = ["pass"]
    if sig.base_args and sig.name == "__init__":
        copier = [
            "codecopy", MemoryPositions.RESERVED_MEMORY, "~codelen",
            sig.base_copy_size
        ]
        context.memory_allocator.expand_memory(sig.max_copy_size)
    clampers.append(copier)

    if check_nonpayable and sig.mutability != "payable":
        # if the contract contains payable functions, but this is not one of them
        # add an assertion that the value of the call is zero
        clampers.append(["assert", ["iszero", "callvalue"]])

    # Fill variable positions. The first `default_args_start_pos` entries of
    # `sig.args` are the base arguments; the rest have default values.
    default_args_start_pos = len(sig.base_args)
    for i, arg in enumerate(sig.args):
        if i < default_args_start_pos:
            # The destination passed here is the next free memory position,
            # i.e. the slot reserved just below when `arg` is a byte array.
            clampers.append(
                make_arg_clamper(
                    arg.pos,
                    context.memory_allocator.get_next_memory_position(),
                    arg.typ,
                    sig.name == "__init__",
                ))
        if isinstance(arg.typ, ByteArrayLike):
            mem_pos = context.memory_allocator.expand_memory(
                32 * get_size_of_type(arg.typ))
            context.vars[arg.name] = VariableRecord(arg.name, mem_pos, arg.typ,
                                                    False)
        elif sig.name == "__init__":
            # Constructor arguments are read from the reserved memory area the
            # copier writes to.
            context.vars[arg.name] = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )
        elif i >= default_args_start_pos:
            # Default args need to be allocated in memory: they are either
            # set to their default value or copied in from calldata.
            type_size = get_size_of_type(arg.typ) * 32
            default_arg_pos = context.memory_allocator.expand_memory(
                type_size)
            context.vars[arg.name] = VariableRecord(
                name=arg.name,
                pos=default_arg_pos,
                typ=arg.typ,
                mutable=False,
            )
        else:
            # Base arguments are read directly from calldata, past the
            # 4-byte method id.
            context.vars[arg.name] = VariableRecord(name=arg.name,
                                                    pos=4 + arg.pos,
                                                    typ=arg.typ,
                                                    mutable=False,
                                                    location="calldata")

    # Return function body
    if sig.name == "__init__":
        o = LLLnode.from_list(
            ["seq"] + clampers +
            [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )
    # Is default function.
    elif sig.is_default_func():
        o = LLLnode.from_list(
            ["seq"] + clampers + [parse_body(code.body, context)] +
            [["stop"]],  # type: ignore
            pos=getpos(code),
        )
    # Is a normal function with default parameters.
    elif sig.total_default_args > 0:
        # Each signature produced by `generate_default_arg_sigs` gets its own
        # signature comparison. A match fills in the defaults that signature
        # does not supply, copies the ones it does supply from calldata, and
        # jumps to the single shared function body (`function_routine`).
        function_routine = f"{sig.name}_{sig.method_id}"
        default_sigs = sig_utils.generate_default_arg_sigs(
            code, context.sigs, context.global_ctx)
        sig_chain: List[Any] = ["seq"]
        base_arg_names = {base_arg.name for base_arg in sig.base_args}

        for default_sig in default_sigs:
            sig_compare, _ = get_sig_statements(default_sig, getpos(code))

            # Populate unset default variables
            set_defaults = []
            for arg_name in get_default_names_to_set(sig, default_sig):
                value = Expr(sig.default_values[arg_name],
                             context).lll_node
                var = context.vars[arg_name]
                left = LLLnode.from_list(
                    var.pos,
                    typ=var.typ,
                    location="memory",
                    pos=getpos(code),
                    mutable=var.mutable,
                )
                set_defaults.append(
                    make_setter(left, value, "memory", pos=getpos(code)))

            copier_arg_count = len(default_sig.args) - len(sig.base_args)
            # Default arguments supplied by this signature, kept in signature
            # order (the order is very important for the generated copiers).
            copier_arg_names = [
                x.name for x in default_sig.args
                if x.name not in base_arg_names
            ]

            # Variables to be populated from calldata/stack.
            default_copiers: List[Any] = []
            if copier_arg_count > 0:
                # Get map of variables in calldata, with their offsets
                # (byte arrays occupy a single 32-byte head slot).
                offset = 4
                calldata_offset_map = {}
                for arg in default_sig.args:
                    calldata_offset_map[arg.name] = offset
                    offset += (32 if isinstance(arg.typ, ByteArrayLike)
                               else get_size_of_type(arg.typ) * 32)

                # Copy default parameters from calldata.
                for arg_name in copier_arg_names:
                    var = context.vars[arg_name]
                    calldata_offset = calldata_offset_map[arg_name]

                    # Add clampers.
                    default_copiers.append(
                        make_arg_clamper(
                            calldata_offset - 4,
                            var.pos,
                            var.typ,
                        ))
                    # Add copying code. For byte arrays the head slot is read
                    # and 4 added to it to locate the data.
                    _offset: Union[int, List[Any]] = calldata_offset
                    if isinstance(var.typ, ByteArrayLike):
                        _offset = [
                            "add", 4, ["calldataload", calldata_offset]
                        ]
                    default_copiers.append(
                        get_external_arg_copier(
                            memory_dest=var.pos,
                            total_size=var.size * 32,
                            offset=_offset,
                        ))

                default_copiers.append(0)  # for over arching seq, POP

            sig_chain.append([
                "if",
                sig_compare,
                [
                    "seq",
                    ["seq"] + set_defaults if set_defaults else ["pass"],
                    ["seq_unchecked"] +
                    default_copiers if default_copiers else ["pass"],
                    ["goto", function_routine],
                ],
            ])

        # The shared body is jumped over when execution falls through the
        # signature checks; it is only entered via `goto function_routine`.
        function_jump_label = f"{sig.name}_{sig.method_id}_skip"
        o = LLLnode.from_list(
            [
                "seq",
                sig_chain,
                [
                    "seq",
                    ["goto", function_jump_label],
                    ["label", function_routine],
                    ["seq"] + nonreentrant_pre + clampers +
                    [parse_body(c, context)
                     for c in code.body] + nonreentrant_post + [["stop"]],
                    ["label", function_jump_label],
                ],
            ],
            typ=None,
            pos=getpos(code),
        )
    # Is a normal function without default parameters.
    else:
        sig_compare, _ = get_sig_statements(sig, getpos(code))
        o = LLLnode.from_list(
            [
                "if",
                sig_compare,
                ["seq"] + nonreentrant_pre + clampers +
                [parse_body(c, context)
                 for c in code.body] + nonreentrant_post + [["stop"]],
            ],
            typ=None,
            pos=getpos(code),
        )
    return o
Ejemplo n.º 8
0
def parse_public_function(code: ast.FunctionDef, sig: FunctionSignature,
                          context: Context) -> LLLnode:
    """
    Parse a public function (FunctionDef) and produce the full function body.

    :param code: ast of function
    :param sig: the FunctionSignature
    :param context: compilation context; argument variables are registered in
                    `context.vars` and memory is reserved through
                    `context.memory_allocator`
    :return: full sig compare & function body
    :raises FunctionDeclarationException: if the default function declares
                                          any arguments
    """

    validate_public_function(code, sig, context.global_ctx)

    # Get nonreentrant lock
    nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(
        sig, context.global_ctx)

    # "Clampers": input well-formedness checks (plus the copier) emitted
    # before the function body.
    clampers: List[Any] = []

    # Generate copiers. Only a constructor with base arguments needs one: it
    # copies `sig.base_copy_size` bytes starting at `~codelen` into reserved
    # memory.
    copier: List[Any] = ['pass']
    if sig.base_args and sig.name == '__init__':
        copier = [
            'codecopy', MemoryPositions.RESERVED_MEMORY, '~codelen',
            sig.base_copy_size
        ]
        context.memory_allocator.increase_memory(sig.max_copy_size)
    clampers.append(copier)

    # Non-payable functions must not receive value.
    if not sig.payable:
        clampers.append(['assert', ['iszero', 'callvalue']])

    # Fill variable positions. The first `default_args_start_pos` entries of
    # `sig.args` are the base arguments; the rest have default values.
    default_args_start_pos = len(sig.base_args)
    for i, arg in enumerate(sig.args):
        if i < default_args_start_pos:
            # The destination passed here is the next free memory position,
            # i.e. the slot reserved just below when `arg` is a byte array.
            clampers.append(
                make_arg_clamper(
                    arg.pos,
                    context.memory_allocator.get_next_memory_position(),
                    arg.typ,
                    sig.name == '__init__',
                ))
        if isinstance(arg.typ, ByteArrayLike):
            mem_pos, _ = context.memory_allocator.increase_memory(
                32 * get_size_of_type(arg.typ))
            context.vars[arg.name] = VariableRecord(arg.name, mem_pos, arg.typ,
                                                    False)
        elif sig.name == '__init__':
            # Constructor arguments are read from the reserved memory area the
            # copier writes to.
            context.vars[arg.name] = VariableRecord(
                arg.name,
                MemoryPositions.RESERVED_MEMORY + arg.pos,
                arg.typ,
                False,
            )
        elif i >= default_args_start_pos:
            # Default args need to be allocated in memory. Reserve room for the
            # whole type, not a single word: the default-value setter writes a
            # full `arg.typ` value here and the calldata copier below writes
            # `var.size * 32` bytes, so a 32-byte slot would be overrun by any
            # multi-word type.
            default_arg_pos, _ = context.memory_allocator.increase_memory(
                get_size_of_type(arg.typ) * 32)
            context.vars[arg.name] = VariableRecord(
                name=arg.name,
                pos=default_arg_pos,
                typ=arg.typ,
                mutable=False,
            )
        else:
            # Base arguments are read directly from calldata, past the
            # 4-byte method id.
            context.vars[arg.name] = VariableRecord(name=arg.name,
                                                    pos=4 + arg.pos,
                                                    typ=arg.typ,
                                                    mutable=False,
                                                    location='calldata')

    # Return function body
    if sig.name == '__init__':
        o = LLLnode.from_list(
            ['seq'] + clampers +
            [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )
    # Is default function.
    elif sig.is_default_func():
        if len(sig.args) > 0:
            raise FunctionDeclarationException(
                'Default function may not receive any arguments.', code)
        o = LLLnode.from_list(
            ['seq'] + clampers +
            [parse_body(code.body, context)],  # type: ignore
            pos=getpos(code),
        )
    # Is a normal function with default parameters.
    elif sig.total_default_args > 0:
        # Each signature produced by `generate_default_arg_sigs` gets its own
        # signature comparison. A match fills in the defaults that signature
        # does not supply, copies the ones it does supply from calldata, and
        # jumps to the single shared function body (`function_routine`).
        function_routine = f"{sig.name}_{sig.method_id}"
        default_sigs = sig_utils.generate_default_arg_sigs(
            code, context.sigs, context.global_ctx)
        sig_chain: List[Any] = ['seq']
        base_arg_names = {base_arg.name for base_arg in sig.base_args}

        for default_sig in default_sigs:
            sig_compare, _ = get_sig_statements(default_sig, getpos(code))

            # Populate unset default variables
            set_defaults = []
            for arg_name in get_default_names_to_set(sig, default_sig):
                value = Expr(sig.default_values[arg_name],
                             context).lll_node
                var = context.vars[arg_name]
                left = LLLnode.from_list(var.pos,
                                         typ=var.typ,
                                         location='memory',
                                         pos=getpos(code),
                                         mutable=var.mutable)
                set_defaults.append(
                    make_setter(left, value, 'memory', pos=getpos(code)))

            copier_arg_count = len(default_sig.args) - len(sig.base_args)
            # Default arguments supplied by this signature, kept in signature
            # order (the order is very important for the generated copiers).
            copier_arg_names = [
                x.name for x in default_sig.args
                if x.name not in base_arg_names
            ]

            # Variables to be populated from calldata/stack.
            default_copiers: List[Any] = []
            if copier_arg_count > 0:
                # Get map of variables in calldata, with their offsets
                # (byte arrays occupy a single 32-byte head slot).
                offset = 4
                calldata_offset_map = {}
                for arg in default_sig.args:
                    calldata_offset_map[arg.name] = offset
                    offset += (32 if isinstance(arg.typ, ByteArrayLike)
                               else get_size_of_type(arg.typ) * 32)

                # Copy default parameters from calldata.
                for arg_name in copier_arg_names:
                    var = context.vars[arg_name]
                    calldata_offset = calldata_offset_map[arg_name]

                    # Add clampers.
                    default_copiers.append(
                        make_arg_clamper(
                            calldata_offset - 4,
                            var.pos,
                            var.typ,
                        ))
                    # Add copying code. For byte arrays the head slot is read
                    # and 4 added to it to locate the data.
                    _offset: Union[int, List[Any]] = calldata_offset
                    if isinstance(var.typ, ByteArrayLike):
                        _offset = [
                            'add', 4, ['calldataload', calldata_offset]
                        ]
                    default_copiers.append(
                        get_public_arg_copier(
                            memory_dest=var.pos,
                            total_size=var.size * 32,
                            offset=_offset,
                        ))

                default_copiers.append(0)  # for over arching seq, POP

            sig_chain.append([
                'if', sig_compare,
                [
                    'seq',
                    ['seq'] + set_defaults if set_defaults else ['pass'],
                    ['seq_unchecked'] +
                    default_copiers if default_copiers else ['pass'],
                    ['goto', function_routine]
                ]
            ])

        # The shared body sits under `if 0`, so it is never reached by falling
        # through; it is only entered via `goto function_routine`.
        o = LLLnode.from_list(
            [
                'seq',
                sig_chain,
                [
                    'if',
                    0,  # can only be jumped into
                    [
                        'seq',
                        ['label', function_routine],
                        ['seq'] + nonreentrant_pre + clampers +
                        [parse_body(c, context) for c in code.body] +
                        nonreentrant_post + [['stop']],
                    ],
                ],
            ],
            typ=None,
            pos=getpos(code))
    # Is a normal function without default parameters.
    else:
        sig_compare, _ = get_sig_statements(sig, getpos(code))
        o = LLLnode.from_list(
            [
                'if', sig_compare,
                ['seq'] + nonreentrant_pre + clampers +
                [parse_body(c, context)
                 for c in code.body] + nonreentrant_post + [['stop']],
            ],
            typ=None,
            pos=getpos(code))
    return o