예제 #1
0
    def handler_for(calldata_kwargs, default_kwargs):
        """Build the IR for one kwarg-arity variant of an external function.

        The variant is selected when ``_calldata_method_id`` equals the method
        id of the ABI signature that includes ``calldata_kwargs``.  On a match
        the generated code asserts on ``calldatasize``, copies every
        calldata-supplied kwarg into its memory slot, stores the default value
        of every omitted kwarg, and finally jumps to the shared entry label.

        Relies on ``sig`` and ``context`` from the enclosing scope.

        :param calldata_kwargs: kwargs that are present in calldata for this variant.
        :param default_kwargs: kwargs that fall back to their declared default.
        :return: an ``["if", <method id matches>, <body>]`` IR list.
        """
        all_calldata_args = sig.base_args + calldata_kwargs
        # synthetic tuple type so that get_element_ptr can index into calldata
        calldata_tuple_t = TupleType([a.typ for a in all_calldata_args])

        selector = _annotated_method_id(sig.abi_signature_for_kwargs(calldata_kwargs))

        # the encoded arguments start at calldata offset 4
        calldata_tuple = IRnode(
            4, location=CALLDATA, typ=calldata_tuple_t, encoding=Encoding.ABI
        )

        # statements that strictify kwargs into memory
        body = ["seq"]

        # calldata must hold at least the minimum encoding; when every argument
        # is static the exact size is known, so require equality instead
        abi_t = calldata_tuple_t.abi_type
        expected_size = abi_t.min_size() + 4
        size_op = "ge" if abi_t.is_dynamic() else "eq"
        body.append(["assert", [size_op, "calldatasize", expected_size]])

        # TODO optimize make_setter by using
        # TupleType(list(arg.typ for arg in calldata_kwargs + default_kwargs))
        # (must ensure memory area is contiguous)

        # calldata kwargs follow the base args inside the encoded tuple
        first_kwarg_ix = len(sig.base_args)
        for ix, kwarg in enumerate(calldata_kwargs, start=first_kwarg_ix):
            dst = IRnode(context.lookup_var(kwarg.name).pos, location=MEMORY, typ=kwarg.typ)
            src = get_element_ptr(calldata_tuple, ix, array_bounds_check=False)
            setter = make_setter(dst, src)
            setter.source_pos = getpos(kwarg.ast_source)
            body.append(setter)

        # omitted kwargs take their declared default, e.g. `3` in x: int = 3
        for kwarg in default_kwargs:
            dst = IRnode(context.lookup_var(kwarg.name).pos, location=MEMORY, typ=kwarg.typ)
            dst.source_pos = getpos(kwarg.ast_source)
            default_ir = Expr(sig.default_values[kwarg.name], context).ir_node
            setter = make_setter(dst, default_ir)
            setter.source_pos = getpos(kwarg.ast_source)
            body.append(setter)

        body.append(["goto", sig.external_function_base_entry_label])

        return ["if", ["eq", "_calldata_method_id", selector], body]
예제 #2
0
파일: stmt.py 프로젝트: skellet0r/vyper
    def _parse_For_list(self):
        """Lower a ``for x in <list>`` statement to a ``repeat`` LLL node.

        The iterable is evaluated inside ``range_scope``.  A list literal is
        first copied into a fresh memory buffer.  Each iteration copies
        element ``ix`` into the memory-allocated loop variable and then runs
        the loop body.  For dynamic arrays the runtime element count is the
        trip count; ``typ.count`` is always passed as the repeat bound.
        """
        with self.context.range_scope():
            iter_ir = Expr(self.stmt.iter, self.context).lll_node

        # override with type inferred at typechecking time
        # TODO investigate why stmt.target.type != stmt.iter.type.subtype
        elem_t = new_type_to_old_type(self.stmt.target._metadata["type"])
        iter_ir.typ.subtype = elem_t

        # user-supplied name for the loop variable
        name = self.stmt.target.id
        loop_var = LLLnode.from_list(
            self.context.new_variable(name, elem_t),
            typ=elem_t,
            location="memory",
        )

        ix = LLLnode.from_list(self.context.fresh_varname("for_list_ix"), typ="uint256")

        self.context.forvars[name] = True

        stmts = ["seq"]

        # a list literal has no address yet, so force it into memory first
        if isinstance(self.stmt.iter, vy_ast.List):
            buf = LLLnode.from_list(
                self.context.new_internal_variable(iter_ir.typ),
                typ=iter_ir.typ,
                location="memory",
            )
            stmts.append(make_setter(buf, iter_ir, pos=getpos(self.stmt)))
            iter_ir = buf

        # per iteration: copy element `ix` into the loop variable, then run the body
        target_pos = getpos(self.stmt.target)
        elem_ptr = get_element_ptr(iter_ir, ix, array_bounds_check=False, pos=target_pos)
        loop_body = [
            "seq",
            make_setter(loop_var, elem_ptr, pos=target_pos),
            parse_body(self.stmt.body, self.context),
        ]

        max_iters = iter_ir.typ.count
        if isinstance(iter_ir.typ, DArrayType):
            n_iters = get_dyn_array_count(iter_ir)
        else:
            n_iters = max_iters

        stmts.append(["repeat", ix, 0, n_iters, max_iters, loop_body])

        del self.context.forvars[name]
        return LLLnode.from_list(stmts, pos=getpos(self.stmt))
예제 #3
0
def _register_function_args(context: Context, sig: FunctionSignature) -> List[LLLnode]:
    """Make the function's base (positional) args available to its body.

    The base args are read from an ABI-encoded tuple.  For the constructor
    the tuple is at offset 0 of ``"data"``; otherwise it is at offset 4 of
    ``"calldata"`` (past the 4-byte method selector).  Each arg for which
    ``_should_decode`` is true is copied into a new immutable memory
    variable.  Every other arg is registered in ``context.vars`` so that it
    points directly at its encoded location.

    :return: the LLL statements that perform the memory copies.
    """
    copies = []

    args_tuple_t = TupleType([a.typ for a in sig.base_args])

    if sig.is_init_func:
        location, offset = "data", 0
    else:
        location, offset = "calldata", 4
    args_tuple = LLLnode(offset, location=location, typ=args_tuple_t, encoding=Encoding.ABI)

    for ix, arg in enumerate(sig.base_args):
        arg_ptr = get_element_ptr(args_tuple, ix, pos=None)

        if not _should_decode(arg.typ):
            # leave it in place: the variable refers to the encoded data
            context.vars[arg.name] = VariableRecord(
                name=arg.name,
                pos=arg_ptr,
                typ=arg.typ,
                mutable=False,
                location=arg_ptr.location,
                encoding=Encoding.ABI,
            )
            continue

        # allocate a memory slot for it and copy
        slot = context.new_variable(arg.name, arg.typ, is_mutable=False)
        dst = LLLnode(slot, typ=arg.typ, location="memory")
        copies.append(make_setter(dst, arg_ptr, pos=None))

    return copies
예제 #4
0
파일: stmt.py 프로젝트: ProGamerCode/vyper
    def parse_Assign(self):
        """Lower an assignment statement (e.g. ``x[4] = y``) to IR.

        The right-hand side is evaluated before the target is resolved.
        ``make_setter`` then builds the store.
        """
        value_ir = Expr(self.stmt.value, self.context).ir_node
        return make_setter(self._get_target(self.stmt.target), value_ir)
예제 #5
0
파일: stmt.py 프로젝트: ProGamerCode/vyper
    def parse_AnnAssign(self):
        """Lower an annotated declaration such as ``x: uint256 = 1``.

        A memory variable is always allocated for the target.  With no
        initializer the method returns ``None``.  Otherwise it returns the IR
        that stores the initializer into the new variable.  A literal 32-byte
        ``Bytes`` value assigned to a ``bytes32`` variable is first rewritten
        as a ``bytes32`` integer constant.
        """
        var_t = parse_type(
            self.stmt.annotation,
            sigs=self.context.sigs,
            custom_structs=self.context.structs,
        )
        var_ofst = self.context.new_variable(self.stmt.target.id, var_t)
        if self.stmt.value is None:
            return None

        init_ir = Expr(self.stmt.value, self.context).ir_node

        src_t = init_ir.typ
        literal_bytes_to_bytes32 = (
            isinstance(src_t, ByteArrayType)
            and src_t.maxlen == 32
            and isinstance(var_t, BaseType)
            and var_t.typ == "bytes32"
            and src_t.is_literal
        )
        if literal_bytes_to_bytes32:
            # reinterpret the literal's bytes as a bytes32 word
            init_ir = IRnode(
                util.bytes_to_int(self.stmt.value.s),
                typ=BaseType("bytes32"),
            )

        dst = IRnode.from_list(var_ofst, typ=var_t, location=MEMORY)
        return make_setter(dst, init_ir)
예제 #6
0
파일: stmt.py 프로젝트: skellet0r/vyper
    def parse_Assign(self):
        """Lower an assignment statement (e.g. ``x[4] = y``) to LLL.

        Evaluates the right-hand side, resolves the target, and tags the
        resulting setter with the statement's source position.
        """
        rhs = Expr(self.stmt.value, self.context).lll_node
        dst = self._get_target(self.stmt.target)

        stmt_pos = getpos(self.stmt)
        setter = make_setter(dst, rhs, pos=stmt_pos)
        setter.pos = stmt_pos
        return setter
예제 #7
0
    def handler_for(calldata_kwargs, default_kwargs):
        """Build the LLL for one kwarg-arity variant of an external function.

        When ``_calldata_method_id`` equals the method id of the ABI
        signature that includes ``calldata_kwargs``, the generated code copies
        each calldata-supplied kwarg into its memory slot, evaluates and
        stores the default value of each omitted kwarg, and then jumps to the
        shared entry label.

        Relies on ``sig``, ``context`` and ``pos`` from the enclosing scope.

        :param calldata_kwargs: kwargs that are present in calldata for this variant.
        :param default_kwargs: kwargs that fall back to their declared default.
        :return: an ``["if", <method id matches>, <body>]`` LLL list.
        """
        # synthetic tuple type so that get_element_ptr can index into calldata
        args_tuple_t = TupleType([a.typ for a in sig.base_args + calldata_kwargs])

        selector = _annotated_method_id(sig.abi_signature_for_kwargs(calldata_kwargs))

        # the encoded arguments start at calldata offset 4
        args_tuple = LLLnode(4, location="calldata", typ=args_tuple_t, encoding=Encoding.ABI)

        # statements that strictify kwargs into memory
        body = ["seq"]

        # TODO optimize make_setter by using
        # TupleType(list(arg.typ for arg in calldata_kwargs + default_kwargs))
        # (must ensure memory area is contiguous)

        # calldata kwargs follow the base args inside the encoded tuple
        first_kwarg_ix = len(sig.base_args)
        for ix, kwarg in enumerate(calldata_kwargs, start=first_kwarg_ix):
            dst = LLLnode(context.lookup_var(kwarg.name).pos, location="memory", typ=kwarg.typ)
            src = get_element_ptr(args_tuple, ix, pos=None, array_bounds_check=False)
            body.append(make_setter(dst, src, pos))

        # omitted kwargs take their declared default, e.g. `3` in x: int = 3
        for kwarg in default_kwargs:
            dst = LLLnode(context.lookup_var(kwarg.name).pos, location="memory", typ=kwarg.typ)
            default_ir = Expr(sig.default_values[kwarg.name], context).lll_node
            body.append(make_setter(dst, default_ir, pos))

        body.append(["goto", sig.external_function_base_entry_label])

        return ["if", ["eq", "_calldata_method_id", selector], body]
예제 #8
0
def _encode_dyn_array_helper(dst, ir_node, context):
    """Emit IR that ABI-encodes the dynamic array ``ir_node`` into ``dst``.

    The emitted IR writes the encoding and updates the IR variable
    ``dyn_ofst``.  That variable must already be bound in the enclosing IR,
    for example by the ``with dyn_ofst ...`` wrapper that ``abi_encode`` emits.

    * A literal (``multi``) array is first copied into a temporary memory
      buffer and then encoded with ``abi_encode``.  ``dyn_ofst`` is *set* to
      the encoded length that call returns.
    * Otherwise the length word is stored, each element is encoded in a
      ``repeat`` loop, and ``32 + <encoded children size>`` is *added* to
      ``dyn_ofst``.

    NOTE(review): the two branches agree only if ``dyn_ofst`` is 0 on entry
    (presumably ``static_size()`` of a dynamic-array ABI type); confirm this.

    :param dst: destination pointer.  Its ``typ`` sizes the temporary buffer
        in the literal case.
    :param ir_node: IR for the dynamic array to encode.
    :param context: codegen context, used for memory allocation and fresh IR
        variable names.
    :return: IR that performs the encoding.
    """
    # if it's a literal, first serialize to memory as we
    # don't have a compile-time abi encoder
    # TODO handle this upstream somewhere
    if ir_node.value == "multi":
        buf = context.new_internal_variable(dst.typ)
        buf = IRnode.from_list(buf, typ=dst.typ, location=MEMORY)
        # passed to abi_encode as its bufsz argument
        _bufsz = dst.typ.abi_type.size_bound()
        return [
            "seq",
            make_setter(buf, ir_node),
            [
                # with returns_len=True, abi_encode yields the encoded length
                "set", "dyn_ofst",
                abi_encode(dst, buf, context, _bufsz, returns_len=True)
            ],
        ]

    subtyp = ir_node.typ.subtype
    child_abi_t = subtyp.abi_type

    ret = ["seq"]

    len_ = get_dyn_array_count(ir_node)
    # bind the runtime length once if it is a complex expression
    with len_.cache_when_complex("len") as (b, len_):
        # set the length word
        ret.append(STORE(dst, len_))

        # prepare the loop
        t = BaseType("uint256")
        i = IRnode.from_list(context.fresh_varname("ix"), typ=t)

        # offset of the i'th element in ir_node
        child_location = get_element_ptr(ir_node, i, array_bounds_check=False)

        # offset of the i'th element in dst
        dst = add_ofst(dst, 32)  # jump past length word
        static_elem_size = child_abi_t.embedded_static_size()
        # each element's head occupies static_elem_size bytes
        static_ofst = ["mul", i, static_elem_size]
        loop_body = _encode_child_helper(dst, child_location, static_ofst,
                                         "dyn_child_ofst", context)
        # runtime trip count is len_; ir_node.typ.count is the static bound
        loop = ["repeat", i, 0, len_, ir_node.typ.count, loop_body]

        # dyn_child_ofst starts just past all element heads (len_ * head size);
        # its final value is the size of the encoded children
        x = ["seq", loop, "dyn_child_ofst"]
        start_dyn_ofst = ["mul", len_, static_elem_size]
        run_children = ["with", "dyn_child_ofst", start_dyn_ofst, x]
        new_dyn_ofst = ["add", "dyn_ofst", run_children]
        # size of dynarray is size of encoded children + size of the length word
        # TODO optimize by adding 32 to the initial value of dyn_ofst
        new_dyn_ofst = ["add", 32, new_dyn_ofst]
        ret.append(["set", "dyn_ofst", new_dyn_ofst])

        return b.resolve(ret)
예제 #9
0
def _register_function_args(context: Context,
                            sig: FunctionSignature) -> List[IRnode]:
    """Make the function's base (positional) args available to its body.

    The base args are read from an ABI-encoded tuple.  For the constructor
    the tuple is at offset 0 of DATA; otherwise it is at offset 4 of
    CALLDATA (past the 4-byte method selector).  Each arg for which
    ``_should_decode`` is true is copied into a new immutable memory
    variable, and the copy is tagged with the arg's source position.  Every
    other arg is registered in ``context.vars`` so that it points directly at
    its encoded location.

    :return: the IR statements that perform the memory copies.
    """
    copies = []

    args_tuple_t = TupleType([a.typ for a in sig.base_args])

    location, offset = (DATA, 0) if sig.is_init_func else (CALLDATA, 4)
    args_tuple = IRnode(offset,
                        location=location,
                        typ=args_tuple_t,
                        encoding=Encoding.ABI)

    for ix, arg in enumerate(sig.base_args):
        arg_ptr = get_element_ptr(args_tuple, ix)

        if not _should_decode(arg.typ):
            # leave it in place: the variable refers to the encoded data
            context.vars[arg.name] = VariableRecord(
                name=arg.name,
                pos=arg_ptr,
                typ=arg.typ,
                mutable=False,
                location=arg_ptr.location,
                encoding=Encoding.ABI,
            )
            continue

        # allocate a memory slot for it and copy
        slot = context.new_variable(arg.name, arg.typ, is_mutable=False)
        setter = make_setter(IRnode(slot, typ=arg.typ, location=MEMORY), arg_ptr)
        setter.source_pos = getpos(arg.ast_source)
        copies.append(setter)

    return copies
예제 #10
0
def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context,
                       expr):
    """Build IR that validates and unpacks the returndata of an external call.

    :param buf: memory offset of the buffer that receives the returndata.
    :param fn_type: callee function type; its ``return_type`` drives decoding.
    :param call_kwargs: call options; ``skip_contract_check`` and
        ``default_return_value`` are read here.
    :param contract_address: callee address, passed to the extcodesize check
        on the default-return-value path.
    :param context: codegen context, used to allocate the decode buffer.
    :param expr: the call expression, used only for the buffer annotation.
    :return: ``(unpacker, ret_ofst, ret_len)``.  ``unpacker`` is IR that
        evaluates to the (possibly decoded) return buffer.  ``ret_ofst`` is the
        original ``buf`` and ``ret_len`` is the ABI size bound of the wrapped
        return type.  A function with no return type yields ``(["pass"], 0, 0)``.
    """
    ast_return_t = fn_type.return_type

    if ast_return_t is None:
        return ["pass"], 0, 0

    return_t = new_type_to_old_type(ast_return_t)

    wrapped_return_t = calculate_type_for_external_return(return_t)

    abi_return_t = wrapped_return_t.abi_type

    min_return_size = abi_return_t.min_size()
    max_return_size = abi_return_t.size_bound()
    # compile-time sanity check on the ABI size bounds
    assert 0 < min_return_size <= max_return_size

    # the raw buffer offset/length, returned to the caller unchanged
    ret_ofst = buf
    ret_len = max_return_size

    encoding = Encoding.ABI

    buf = IRnode.from_list(
        buf,
        typ=wrapped_return_t,
        location=MEMORY,
        encoding=encoding,
        annotation=f"{expr.node_source_code} returndata buffer",
    )

    unpacker = ["seq"]

    # revert when returndatasize is not in bounds
    # (except when return_override is provided.)
    # NOTE: the assert itself is gated only on skip_contract_check; when a
    # default return value is given, the `if` built below skips this whole
    # unpacker for returndatasize == 0.
    if not call_kwargs.skip_contract_check:
        unpacker.append(["assert", ["ge", "returndatasize", min_return_size]])

    assert isinstance(wrapped_return_t, TupleType)

    # unpack strictly
    if needs_clamp(wrapped_return_t, encoding):
        return_buf = context.new_internal_variable(wrapped_return_t)
        return_buf = IRnode.from_list(return_buf,
                                      typ=wrapped_return_t,
                                      location=MEMORY)

        # note: make_setter does ABI decoding and clamps
        unpacker.append(make_setter(return_buf, buf))
    else:
        # no validation needed: read the returndata buffer in place
        return_buf = buf

    if call_kwargs.default_return_value is not None:
        # if returndatasize == 0:
        #    copy return override to buf
        # else:
        #    do the other stuff

        override_value = wrap_value_for_external_return(
            call_kwargs.default_return_value)
        stomp_return_buffer = ["seq"]
        if not call_kwargs.skip_contract_check:
            stomp_return_buffer.append(_extcodesize_check(contract_address))
        stomp_return_buffer.append(make_setter(return_buf, override_value))
        unpacker = [
            "if", ["eq", "returndatasize", 0], stomp_return_buffer, unpacker
        ]

    # the whole sequence evaluates to the buffer holding the return value
    unpacker = ["seq", unpacker, return_buf]

    return unpacker, ret_ofst, ret_len
예제 #11
0
def lll_for_self_call(stmt_expr, context):
    """Generate LLL for an internal call of the form ``self.<method>(...)``.

    The positional args and the keyword values returned by
    ``context.lookup_internal_function`` are evaluated and copied into the
    callee's frame at ``sig.frame_start``.  If any argument itself contains a
    self call, the args are staged through a temporary buffer first.  The
    code then jumps to the callee's label, passing the return buffer (when
    there is a return type) and a freshly generated return label.

    :param stmt_expr: the call AST node; ``stmt_expr.func.attr`` names the callee.
    :param context: codegen context of the caller.
    :return: LLL typed as ``sig.return_type`` with ``is_self_call`` set.  When
        the callee returns a value, the node evaluates to the return buffer.
    :raises StateAccessViolation: if a function that is not ``view``/``pure``
        is called from a constant context.
    :raises StructureException: if the callee is not an internal function.
    """
    from vyper.codegen.expr import Expr  # TODO rethink this circular import

    pos = getpos(stmt_expr)

    # ** Internal Call **
    # Steps:
    # - copy arguments into the soon-to-be callee
    # - allocate return buffer
    # - push jumpdest (callback ptr) and return buffer location
    # - jump to label
    # - (private function will fill return buffer and jump back)

    method_name = stmt_expr.func.attr

    pos_args_lll = [Expr(x, context).lll_node for x in stmt_expr.args]

    # kw_vals are AST nodes for the keyword args (presumably including declared
    # defaults for omitted kwargs; lookup_internal_function is not shown)
    sig, kw_vals = context.lookup_internal_function(method_name, pos_args_lll)

    kw_args_lll = [Expr(x, context).lll_node for x in kw_vals]

    args_lll = pos_args_lll + kw_args_lll

    args_tuple_t = TupleType([x.typ for x in args_lll])
    args_as_tuple = LLLnode.from_list(["multi"] + [x for x in args_lll], typ=args_tuple_t)

    # register callee to help calculate our starting frame offset
    context.register_callee(sig.frame_size)

    if context.is_constant() and sig.mutability not in ("view", "pure"):
        raise StateAccessViolation(
            f"May not call state modifying function "
            f"'{method_name}' within {context.pp_constancy()}.",
            getpos(stmt_expr),
        )

    # TODO move me to type checker phase
    if not sig.internal:
        raise StructureException("Cannot call external functions via 'self'", stmt_expr)

    return_label = _generate_label(f"{sig.internal_function_label}_call")

    # allocate space for the return buffer
    # TODO allocate in stmt and/or expr.py
    if sig.return_type is not None:
        return_buffer = LLLnode.from_list(
            context.new_internal_variable(sig.return_type), annotation=f"{return_label}_return_buf"
        )
    else:
        return_buffer = None

    # note: dst_tuple_t != args_tuple_t
    # (dst uses the callee's declared arg types, args_tuple_t the evaluated ones)
    dst_tuple_t = TupleType([arg.typ for arg in sig.args])
    args_dst = LLLnode(sig.frame_start, typ=dst_tuple_t, location="memory")

    # if one of the arguments is a self call, the argument
    # buffer could get borked. to prevent against that,
    # write args to a temporary buffer until all the arguments
    # are fully evaluated.
    if args_as_tuple.contains_self_call:
        copy_args = ["seq"]
        # TODO deallocate me
        tmp_args_buf = LLLnode(
            context.new_internal_variable(dst_tuple_t),
            typ=dst_tuple_t,
            location="memory",
        )
        copy_args.append(
            # --> args evaluate here <--
            make_setter(tmp_args_buf, args_as_tuple, pos)
        )

        copy_args.append(make_setter(args_dst, tmp_args_buf, pos))

    else:
        copy_args = make_setter(args_dst, args_as_tuple, pos)

    goto_op = ["goto", sig.internal_function_label]
    # pass return buffer to subroutine
    if return_buffer is not None:
        goto_op += [return_buffer]
    # pass return label to subroutine
    goto_op += [push_label_to_stack(return_label)]

    call_sequence = [
        "seq",
        copy_args,
        goto_op,
        ["label", return_label, ["var_list"], "pass"],
    ]
    if return_buffer is not None:
        # push return buffer location to stack
        call_sequence += [return_buffer]

    o = LLLnode.from_list(
        call_sequence,
        typ=sig.return_type,
        location="memory",
        pos=pos,
        annotation=stmt_expr.get("node_source_code"),
        add_gas_estimate=sig.gas,
    )
    o.is_self_call = True
    return o
예제 #12
0
파일: expr.py 프로젝트: subhrajyoti21/viper
    def build_in_comparator(self):
        """Lower ``x in arr`` / ``x not in arr`` to a linear scan.

        A memory flag is initialized to ``not_found``.  A ``repeat`` loop then
        compares the needle with each element; on the first match it stores
        ``found`` and breaks.  Finally the flag is loaded.  A literal array is
        first copied into memory.

        :return: a ``bool`` LLL node, or ``None`` for an unexpected operator
            (unreachable in tests).
        :raises TypeMismatch: if the left operand is not a base type (#2637).
        """
        left = Expr(self.expr.left, self.context).lll_node
        right = Expr(self.expr.right, self.context).lll_node

        # temporary kludge to block #2637 bug
        # TODO actually fix the bug
        if not isinstance(left.typ, BaseType):
            raise TypeMismatch(
                "`in` not allowed for arrays of non-base types, tracked in issue #2637",
                self.expr)

        # value stored in the flag on a match / left in it when nothing matches
        if isinstance(self.expr.op, vy_ast.In):
            found, not_found = 1, 0
        elif isinstance(self.expr.op, vy_ast.NotIn):
            found, not_found = 0, 1
        else:
            return  # pragma: notest

        i = LLLnode.from_list(self.context.fresh_varname("in_ix"),
                              typ="uint256")

        # memory slot holding the result flag
        found_ptr = self.context.new_internal_variable(BaseType("bool"))

        ret = ["seq"]

        left = unwrap_location(left)
        # evaluate needle and haystack only once each
        with left.cache_when_complex("needle") as (
                b1, left), right.cache_when_complex("haystack") as (b2, right):
            if right.value == "multi":
                # Copy literal to memory to be compared.
                tmp_list = LLLnode.from_list(
                    self.context.new_internal_variable(right.typ),
                    typ=right.typ,
                    location="memory",
                )
                ret.append(make_setter(tmp_list, right, pos=getpos(self.expr)))

                right = tmp_list

            # location of i'th item from list
            pos = getpos(self.expr)
            ith_element_ptr = get_element_ptr(right,
                                              i,
                                              array_bounds_check=False,
                                              pos=pos)
            ith_element = unwrap_location(ith_element_ptr)

            # static arrays: compile-time length; dynamic: runtime length
            if isinstance(right.typ, SArrayType):
                len_ = right.typ.count
            else:
                len_ = get_dyn_array_count(right)

            # Condition repeat loop has to break on.
            # TODO maybe put result on the stack
            loop_body = [
                "if",
                ["eq", left, ith_element],
                # store `found` (1 for `in`, 0 for `not in`) and stop scanning
                ["seq", ["mstore", found_ptr, found], "break"],
            ]
            loop = ["repeat", i, 0, len_, right.typ.count, loop_body]

            ret.append([
                "seq",
                ["mstore", found_ptr, not_found],
                loop,
                ["mload", found_ptr],
            ])

            return LLLnode.from_list(b1.resolve(b2.resolve(ret)), typ="bool")
예제 #13
0
def make_return_stmt(lll_val: LLLnode, stmt: Any,
                     context: Context) -> Optional[LLLnode]:
    """Lower a ``return`` statement to LLL.

    Every path ends with ``exit_to`` the function's exit sequence label,
    preceded by the return-buffer fill and, inside loops, ``cleanup_repeat``:

    * no return type: returns ``None`` if ``stmt.value`` is set (which the
      caller turns into an error); otherwise the exit receives ``return_pc``;
    * internal function: ``lll_val`` is copied into ``return_buffer`` and the
      exit receives ``return_pc``;
    * external function: the value is ABI-encoded into a new buffer and the
      exit receives the buffer offset and the encoded length.

    :param lll_val: the returned value (may be unused when there is no return type).
    :param stmt: the ``return`` AST node.
    :param context: codegen context of the enclosing function.
    """

    sig = context.sig

    # exit_to args are appended below before finalize() reads this list
    jump_to_exit = ["exit_to", f"_sym_{sig.exit_sequence_label}"]

    _pos = getpos(stmt)

    if context.return_type is None:
        if stmt.value is not None:
            return None  # triggers an exception

    else:
        # sanity typecheck
        check_assign(dummy_node_for_type(context.return_type), lll_val)

    # helper function
    def finalize(fill_return_buffer):
        # do NOT bypass this. jump_to_exit may do important function cleanup.
        fill_return_buffer = LLLnode.from_list(
            fill_return_buffer,
            annotation=f"fill return buffer {sig._lll_identifier}")
        cleanup_loops = "cleanup_repeat" if context.forvars else "pass"
        # NOTE: because stack analysis is incomplete, cleanup_repeat must
        # come after fill_return_buffer otherwise the stack will break
        return LLLnode.from_list(
            ["seq", fill_return_buffer, cleanup_loops, jump_to_exit],
            pos=_pos,
        )

    if context.return_type is None:
        jump_to_exit += ["return_pc"]
        return finalize(["pass"])

    if context.is_internal:
        dst = LLLnode.from_list(["return_buffer"],
                                typ=context.return_type,
                                location="memory")
        fill_return_buffer = make_setter(dst, lll_val, pos=_pos)
        jump_to_exit += ["return_pc"]

        return finalize(fill_return_buffer)

    else:  # return from external function

        lll_val = wrap_value_for_external_return(lll_val)

        external_return_type = calculate_type_for_external_return(
            context.return_type)
        # buffer is sized for the largest possible ABI encoding
        maxlen = external_return_type.abi_type.size_bound()
        return_buffer_ofst = context.new_internal_variable(
            get_type_for_exact_size(maxlen))

        # encode_out is cleverly a sequence which does the abi-encoding and
        # also returns the length of the output as a stack element
        encode_out = abi_encode(return_buffer_ofst,
                                lll_val,
                                context,
                                pos=_pos,
                                returns_len=True,
                                bufsz=maxlen)

        # Previously the return buffer was filled and its location and length
        # were pushed onto the stack inside a `seq_unchecked`, leaving them for
        # the function cleanup routine, which expects return_ofst and
        # return_len on the stack. CMC introduced `goto` with args, so we now
        # use a plain `seq` and append the cleanup arguments to the
        # `jump_to_exit` list instead (see vyper/codegen/self_call.py for an
        # example).
        jump_to_exit += [return_buffer_ofst, encode_out]  # type: ignore

        return finalize(["pass"])
예제 #14
0
def abi_encode(dst, ir_node, context, bufsz, returns_len=False):
    """Emit IR that ABI-encodes ``ir_node`` into the memory buffer at ``dst``.

    When the ABI encoding of the type matches Vyper's memory layout, a single
    ``make_setter`` copy is used.  Otherwise the value is encoded piecewise,
    with the IR variable ``dyn_ofst`` tracking the next free position in the
    dynamic section.

    :param dst: destination; wrapped as a MEMORY IRnode typed as ``ir_node.typ``.
    :param ir_node: the value to encode.
    :param context: codegen context, passed through to the helpers.
    :param bufsz: size of the buffer at ``dst`` in bytes.  It must be an ``int``
        no smaller than the type's ABI size bound.
    :param returns_len: if true, the emitted IR also evaluates to the encoded
        length (zero-padded for bytes/strings).
    :raises CompilerPanic: if ``bufsz`` is too small, if the ABI size bound is
        below the type's memory size, or if the type cannot be encoded.
    """
    # TODO change dst to be an IRnode so it has type info to begin with.
    # setting the typ of dst to ir_node.typ is a footgun.
    dst = IRnode.from_list(dst, typ=ir_node.typ, location=MEMORY)
    abi_t = dst.typ.abi_type
    size_bound = abi_t.size_bound()

    assert isinstance(bufsz, int)
    if bufsz < size_bound:
        raise CompilerPanic("buffer provided to abi_encode not large enough")

    if size_bound < dst.typ.memory_bytes_required:
        raise CompilerPanic("Bad ABI size calc")

    annotation = f"abi_encode {ir_node.typ}"
    ir_ret = ["seq"]

    # fastpath: if there is no dynamic data, we can optimize the
    # encoding by using make_setter, since our memory encoding happens
    # to be identical to the ABI encoding.
    if abi_encoding_matches_vyper(ir_node.typ):
        # NOTE: make_setter handles changes of location and encoding
        ir_ret.append(make_setter(dst, ir_node))
        if returns_len:
            assert abi_t.embedded_static_size(
            ) == ir_node.typ.memory_bytes_required
            ir_ret.append(abi_t.embedded_static_size())
        return IRnode.from_list(ir_ret, annotation=annotation)

    # contains some computation, we need to only do it once.
    with ir_node.cache_when_complex("to_encode") as (
            b1, ir_node), dst.cache_when_complex("dst") as (b2, dst):

        dyn_ofst = "dyn_ofst"  # current offset in the dynamic section

        if isinstance(ir_node.typ, BaseType):
            ir_ret.append(make_setter(dst, ir_node))
        elif isinstance(ir_node.typ, ByteArrayLike):
            # TODO optimize out repeated ceil32 calculation
            ir_ret.append(make_setter(dst, ir_node))
            # pad the data out to a 32-byte boundary
            ir_ret.append(zero_pad(dst))
        elif isinstance(ir_node.typ, DArrayType):
            ir_ret.append(_encode_dyn_array_helper(dst, ir_node, context))

        elif isinstance(ir_node.typ, (TupleLike, SArrayType)):
            # encode each member's head at static_ofst; dynamic members'
            # data goes at dyn_ofst (handled by _encode_child_helper)
            static_ofst = 0
            elems = _deconstruct_complex_type(ir_node)
            for e in elems:
                encode_ir = _encode_child_helper(dst, e, static_ofst, dyn_ofst,
                                                 context)
                ir_ret.extend(encode_ir)
                static_ofst += e.typ.abi_type.embedded_static_size()

        else:
            raise CompilerPanic(f"unencodable type: {ir_node.typ}")

        # declare IR variables.
        if returns_len:
            if not abi_t.is_dynamic():
                ir_ret.append(abi_t.embedded_static_size())
            elif isinstance(ir_node.typ, ByteArrayLike):
                # for abi purposes, return zero-padded length
                calc_len = ["ceil32", ["add", 32, ["mload", dst]]]
                ir_ret.append(calc_len)
            elif abi_t.is_complex_type():
                # after encoding, dyn_ofst is the end of the dynamic section
                ir_ret.append("dyn_ofst")
            else:
                raise CompilerPanic(f"unknown type {ir_node.typ}")

        if abi_t.is_dynamic() and abi_t.is_complex_type():
            # the dynamic section starts right after the static (head) section
            dyn_section_start = abi_t.static_size()
            ir_ret = ["with", dyn_ofst, dyn_section_start, ir_ret]
        else:
            pass  # skip dyn_ofst allocation if we don't need it

        return b1.resolve(
            b2.resolve(IRnode.from_list(ir_ret, annotation=annotation)))
예제 #15
0
def make_return_stmt(ir_val: IRnode, stmt: Any, context: Context) -> Optional[IRnode]:
    """Lower a ``return`` statement to IR.

    Every path ends with ``exit_to`` the function's exit sequence label,
    preceded by the return-buffer fill and, inside loops, ``cleanup_repeat``:

    * no return type: returns ``None`` if ``stmt.value`` is set (which the
      caller turns into an error); otherwise the exit receives ``return_pc``;
    * internal function: ``ir_val`` is copied into ``return_buffer`` and the
      exit receives ``return_pc``;
    * external function: the exit receives the (offset, length) of the
      ABI-encoded value.  ``ir_val`` is reused directly when it already sits
      in memory in a validated, ABI-compatible form.
    """
    sig = context.sig
    exit_op = ["exit_to", f"_sym_{sig.exit_sequence_label}"]

    def finalize(fill_return_buffer):
        # do NOT bypass this. jump_to_exit may do important function cleanup.
        fill_ir = IRnode.from_list(
            fill_return_buffer, annotation=f"fill return buffer {sig._ir_identifier}"
        )
        # NOTE: because stack analysis is incomplete, cleanup_repeat must
        # come after fill_return_buffer otherwise the stack will break
        cleanup = "cleanup_repeat" if context.forvars else "pass"
        return IRnode.from_list(["seq", fill_ir, cleanup, exit_op])

    if context.return_type is None:
        if stmt.value is not None:
            return None  # triggers an exception
        exit_op.append("return_pc")
        return finalize(["pass"])

    # sanity typecheck
    check_assign(dummy_node_for_type(context.return_type), ir_val)

    if context.is_internal:
        ret_buf = IRnode.from_list(["return_buffer"], typ=context.return_type, location=MEMORY)
        fill = make_setter(ret_buf, ir_val)
        exit_op.append("return_pc")
        return finalize(fill)

    # external function from here on
    external_return_t = calculate_type_for_external_return(context.return_type)
    maxlen = external_return_t.abi_type.size_bound()

    # if the value is already ABI-encoded in memory and validated (it could
    # be unvalidated ABI-encoded returndata, for example), return its buffer
    already_encoded = (
        abi_encoding_matches_vyper(ir_val.typ)
        and ir_val.location == MEMORY
        and not needs_clamp(ir_val.typ, ir_val.encoding)
    )
    if already_encoded:
        assert ir_val.typ.memory_bytes_required == maxlen  # type: ignore
        exit_op.extend([ir_val, maxlen])  # type: ignore
        return finalize(["pass"])

    wrapped_val = wrap_value_for_external_return(ir_val)

    # general case: abi_encode into a newly allocated buffer. The encode IR
    # also evaluates to the encoded length.
    ret_buf_ofst = context.new_internal_variable(get_type_for_exact_size(maxlen))
    ret_len = abi_encode(ret_buf_ofst, wrapped_val, context, returns_len=True, bufsz=maxlen)

    # the cleanup subroutine receives the buffer offset and length via exit_to
    exit_op.extend([ret_buf_ofst, ret_len])  # type: ignore
    return finalize(["pass"])
예제 #16
0
파일: module.py 프로젝트: skellet0r/vyper
def parse_regular_functions(
    o, regular_functions, sigs, external_interfaces, global_ctx, default_function
):
    """
    Generate LLL for the contract's functions and append the deploy code to `o`.

    The runtime bytecode is laid out as: external function dispatch, then the
    fallback (default) function, then internal functions. This way execution
    never has to jump over internal function bodies.

    Arguments
    ---------
    o : list
        LLL sequence being built. The code that returns the runtime bytecode
        (plus the immutables data section, if any) is appended to it in place.
    regular_functions : list
        Function AST nodes to compile. The default function is compiled
        separately, from `default_function`.
    sigs : dict
        Signatures callable via `self`. It is updated in place with the
        signature (gas, frame start, frame size) of every compiled function.
    external_interfaces : dict
        Interface signatures, merged with `{"self": sigs}` when compiling.
    global_ctx :
        Global context. Its `_defs`, `_globals` and `_structs` are read.
    default_function :
        AST node of the default function, or None if the contract has none.

    Returns
    -------
    tuple
        `(o, runtime)`, where `runtime` is the runtime LLL sequence.
    """
    # check for payable/nonpayable external functions to optimize nonpayable assertions
    func_types = [fn._metadata["type"] for fn in global_ctx._defs]
    mutabilities = [
        t.mutability for t in func_types if t.visibility == FunctionVisibility.EXTERNAL
    ]
    has_payable = any(m == StateMutability.PAYABLE for m in mutabilities)
    has_nonpayable = any(m != StateMutability.PAYABLE for m in mutabilities)

    is_default_payable = (
        default_function is not None
        and default_function._metadata["type"].mutability == StateMutability.PAYABLE
    )

    # TODO streamline the nonpayable check logic

    # When a contract has a payable default function and at least one nonpayable
    # external function, we must perform the nonpayable check on every function.
    # The fallback is placed right after the external dispatch sequence, so a
    # single shared `callvalue` assertion there would presumably also be hit on
    # the way to the payable fallback.
    check_per_function = is_default_payable and has_nonpayable

    # generate LLL for regular functions
    payable_funcs = []
    nonpayable_funcs = []
    internal_funcs = []
    # Gas overhead added to each function's estimate. It starts at the init
    # code's gas and grows by 30 for every external function compiled so far.
    add_gas = func_init_lll().gas

    for func_node in regular_functions:
        func_type = func_node._metadata["type"]
        func_lll, frame_start, frame_size = generate_lll_for_function(
            func_node, {"self": sigs, **external_interfaces}, global_ctx, check_per_function
        )

        if func_type.visibility == FunctionVisibility.INTERNAL:
            internal_funcs.append(func_lll)

        elif func_type.mutability == StateMutability.PAYABLE:
            add_gas += 30  # CMC 20210910 why?
            payable_funcs.append(func_lll)

        else:
            add_gas += 30  # CMC 20210910 why?
            nonpayable_funcs.append(func_lll)

        func_lll.total_gas += add_gas

        # update sigs with metadata gathered from compiling the function so that
        # we can handle calls to self
        # TODO we only need to do this for internal functions; external functions
        # cannot be called via `self`
        sig = FunctionSignature.from_definition(func_node, external_interfaces, global_ctx._structs)
        sig.gas = func_lll.total_gas
        sig.frame_start = frame_start
        sig.frame_size = frame_size
        sigs[sig.name] = sig

    # generate LLL for fallback function
    if default_function:
        fallback_lll, _frame_start, _frame_size = generate_lll_for_function(
            default_function,
            {"self": sigs, **external_interfaces},
            global_ctx,
            # include a nonpayable check here if the contract only has a default function
            check_per_function or not regular_functions,
        )
    else:
        fallback_lll = LLLnode.from_list(["revert", 0, 0], typ=None, annotation="Default function")

    if check_per_function:
        # every nonpayable function carries its own callvalue check
        external_seq = ["seq"] + payable_funcs + nonpayable_funcs

    else:
        # payable functions are placed prior to nonpayable functions
        # and separated by a nonpayable assertion
        external_seq = ["seq"]
        if has_payable:
            external_seq += payable_funcs
        if has_nonpayable:
            external_seq.append(["assert", ["iszero", "callvalue"]])
            external_seq += nonpayable_funcs

    # bytecode is organized by: external functions, fallback fn, internal functions
    # this way we save gas and reduce bytecode by not jumping over internal functions
    runtime = [
        "seq",
        func_init_lll(),
        ["with", "_calldata_method_id", ["mload", 0], external_seq],
        ["seq", ["label", "fallback"], fallback_lll],
    ]
    runtime.extend(internal_funcs)

    immutables = [g for g in global_ctx._globals.values() if g.is_immutable]

    # TODO: enable usage of the data section beyond just user defined immutables
    # https://github.com/vyperlang/vyper/pull/2466#discussion_r722816358
    if immutables:
        # find position of the last immutable so we do not overwrite it in memory
        # when we codecopy the runtime code to memory
        immutables = sorted(immutables, key=lambda imm: imm.pos)
        start_pos = immutables[-1].pos + immutables[-1].size * 32

        # create sequence of actions to copy immutables to the end of the runtime code in memory
        # TODO: if possible, just use identity precompile
        data_section = []
        for immutable in immutables:
            # copy each immutable from its memory slot to the end of the runtime code
            lhs = LLLnode.from_list(
                ["add", start_pos + immutable.data_offset, "_lllsz"],
                typ=immutable.typ,
                location="memory",
            )
            rhs = LLLnode.from_list(immutable.pos, typ=immutable.typ, location="memory")
            data_section.append(make_setter(lhs, rhs, pos=None))

        # TODO: use GlobalContext.immutable_section_size
        data_section_size = sum(immutable.size * 32 for immutable in immutables)
        o.append(
            [
                "with",
                "_lllsz",  # keep size of runtime bytecode in sz var
                ["lll", start_pos, runtime],  # store runtime code at `start_pos`
                # sequence of copying immutables, with final action of returning the runtime code
                ["seq", *data_section, ["return", start_pos, ["add", data_section_size, "_lllsz"]]],
            ]
        )

    else:
        # NOTE: lll macro first argument is the location in memory to store
        # the compiled bytecode
        # https://lll-docs.readthedocs.io/en/latest/lll_reference.html#code-lll
        o.append(["return", 0, ["lll", 0, runtime]])

    return o, runtime