Example #1
0
def lll_for_self_call(stmt_expr, context):
    """Compile ``self.<name>(...)`` into an internal-function call sequence.

    The generated LLL:
      1. evaluates all arguments (positional ones, then the keyword values
         returned by ``context.lookup_internal_function``) and copies them
         into the callee's frame at ``sig.frame_start``. When an argument
         itself contains a self call, the arguments are staged in a
         temporary buffer first, since that nested call could overwrite the
         frame while the remaining arguments are still being evaluated;
      2. jumps to the callee's label, passing the return buffer and the
         return label;
      3. places the return label, then leaves the return buffer on the
         stack as the value of the expression.

    Raises:
        StateAccessViolation: calling a function whose mutability is not
            view/pure from a constant context.
        StructureException: the target signature is not internal.
    """
    from vyper.codegen.expr import Expr  # TODO rethink this circular import

    pos = getpos(stmt_expr)
    callee_name = stmt_expr.func.attr

    def _lower(nodes):
        return [Expr(node, context).lll_node for node in nodes]

    positional = _lower(stmt_expr.args)
    sig, kw_defaults = context.lookup_internal_function(callee_name, positional)
    all_args = positional + _lower(kw_defaults)

    packed_args = LLLnode.from_list(["multi", *all_args],
                                    typ=TupleType([a.typ for a in all_args]))

    # register callee to help calculate our starting frame offset
    context.register_callee(sig.frame_size)

    if context.is_constant():
        if sig.mutability not in ("view", "pure"):
            raise StateAccessViolation(
                f"May not call state modifying function "
                f"'{callee_name}' within {context.pp_constancy()}.",
                getpos(stmt_expr),
            )

    # TODO move me to type checker phase
    if not sig.internal:
        raise StructureException("Cannot call external functions via 'self'",
                                 stmt_expr)

    return_label = _generate_label(f"{sig.internal_function_label}_call")

    # allocate space for the return buffer
    # TODO allocate in stmt and/or expr.py
    if sig.return_type is None:
        buffer_value = "pass"
    else:
        buffer_value = context.new_internal_variable(sig.return_type)
    return_buffer = LLLnode.from_list([buffer_value],
                                      annotation=f"{return_label}_return_buf")

    # the frame's tuple type comes from the signature and is not necessarily
    # the same as the type of the packed argument expressions
    frame_t = TupleType([arg.typ for arg in sig.args])
    frame = LLLnode(sig.frame_start, typ=frame_t, location="memory")

    if not packed_args.contains_self_call:
        copy_args = make_setter(frame, packed_args, pos)
    else:
        # TODO deallocate me
        staging = LLLnode(
            context.new_internal_variable(frame_t),
            typ=frame_t,
            location="memory",
        )
        copy_args = [
            "seq",
            make_setter(staging, packed_args, pos),  # --> args evaluate here <--
            make_setter(frame, staging, pos),
        ]

    jump_to_callee = [
        "goto",
        sig.internal_function_label,
        return_buffer,  # pass return buffer to subroutine
        push_label_to_stack(return_label),  # pass return label to subroutine
    ]
    call_sequence = [
        "seq",
        copy_args,
        jump_to_callee,
        ["label", return_label],
        return_buffer,  # push return buffer location to stack
    ]

    result = LLLnode.from_list(
        call_sequence,
        typ=sig.return_type,
        location="memory",
        pos=pos,
        annotation=stmt_expr.get("node_source_code"),
        add_gas_estimate=sig.gas,
    )
    result.is_self_call = True
    return result
Example #2
0
def test_compile_lll_good(good_lll, get_contract_from_lll):
    """Each well-formed LLL program in ``good_lll`` is accepted by the
    ``get_contract_from_lll`` fixture without raising."""
    program = LLLnode.from_list(good_lll)
    get_contract_from_lll(program)
Example #3
0
def test_pc_debugger():
    """A ``pc_debugger`` node records its program counter as a breakpoint."""
    program = LLLnode.from_list(["seq", ["mstore", 0, 32], ["pc_debugger"]])
    assembly = compile_lll.compile_to_assembly(program)
    _bytecode, line_number_map = compile_lll.assembly_to_evm(assembly)
    # mstore(0, 32) should assemble to PUSH1 32 PUSH1 0 MSTORE (5 bytes),
    # so the breakpoint is expected at pc 5
    assert line_number_map["pc_breakpoints"][0] == 5
Example #4
0
def _compile_to_assembly(code,
                         withargs=None,
                         existing_labels=None,
                         break_dest=None,
                         height=0):
    """Recursively compile an LLL node into a flat list of assembly items.

    Args:
        code: the LLLnode to compile.
        withargs: maps each name bound by an enclosing ``with`` or ``repeat``
            to the stack height at which its value lives. Variable reads and
            ``set`` turn the height difference into DUPn / SWAPn.
        existing_labels: names of labels emitted so far. Used to reject
            duplicate ``label`` nodes.
        break_dest: ``(exit_label, continue_label, loop_height)`` for the
            innermost enclosing ``repeat``, or None outside of any loop.
        height: number of stack items pushed (since this compilation
            started) before ``code`` runs. Every child must be compiled with
            the height the stack will actually have when that child runs.

    Returns:
        A list of assembly items: opcode names, ``PUSHn`` data bytes and
        symbol names. An ``lll`` node contributes a nested list holding a
        separately assembled sub-program.

    Raises:
        CompilerPanic: on malformed compiler-generated input (bad argument
            types, bad ``repeat`` arity, shadowed loop variables, and
            ``break``/``continue``/``cleanup_repeat`` outside a loop).
        Exception: on invalid LLL (out-of-range literals, stack references
            deeper than 16, bad ``set``, duplicate labels, statically failing
            clamps, unknown node values).
    """
    if withargs is None:
        withargs = {}
    if not isinstance(withargs, dict):
        raise CompilerPanic(f"Incorrect type for withargs: {type(withargs)}")

    if existing_labels is None:
        existing_labels = set()
    if not isinstance(existing_labels, set):
        raise CompilerPanic(
            f"Incorrect type for existing_labels: {type(existing_labels)}")

    # Opcodes
    if isinstance(code.value, str) and code.value.upper() in get_opcodes():
        # push arguments last-to-first so that args[0] ends up on top
        o = []
        for i, c in enumerate(code.args[::-1]):
            o.extend(
                _compile_to_assembly(c, withargs, existing_labels, break_dest,
                                     height + i))
        o.append(code.value.upper())
        return o
    # Numbers
    elif isinstance(code.value, int):
        if code.value < -(2**255):
            raise Exception(f"Value too low: {code.value}")
        elif code.value >= 2**256:
            raise Exception(f"Value too high: {code.value}")
        # negative literals are pushed as their two's-complement word
        bytez = num_to_bytearray(code.value % 2**256) or [0]
        return ["PUSH" + str(len(bytez))] + bytez
    # Variables connected to with statements
    elif isinstance(code.value, str) and code.value in withargs:
        if height - withargs[code.value] > 16:
            raise Exception("With statement too deep")
        return ["DUP" + str(height - withargs[code.value])]
    # Setting variables connected to with statements
    elif code.value == "set":
        if len(code.args) != 2 or code.args[0].value not in withargs:
            raise Exception(
                "Set expects two arguments, the first being a stack variable")
        if height - withargs[code.args[0].value] > 16:
            raise Exception("With statement too deep")
        # push the new value, swap it into the variable's slot, drop the old
        return _compile_to_assembly(
            code.args[1], withargs, existing_labels, break_dest, height) + [
                "SWAP" + str(height - withargs[code.args[0].value]),
                "POP",
            ]
    # Pass statements
    elif code.value in ("pass", "dummy"):
        return []
    # Code length
    elif code.value == "~codelen":
        return ["_sym_codeend"]
    # Calldataload equivalent for code
    elif code.value == "codeload":
        return _compile_to_assembly(
            LLLnode.from_list([
                "seq",
                ["codecopy", MemoryPositions.FREE_VAR_SPACE, code.args[0], 32],
                ["mload", MemoryPositions.FREE_VAR_SPACE],
            ]),
            withargs,
            existing_labels,
            break_dest,
            height,
        )
    # If statements (2 arguments, ie. if x: y)
    elif code.value == "if" and len(code.args) == 2:
        o = []
        o.extend(
            _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height))
        end_symbol = mksymbol("join")
        o.extend(["ISZERO", end_symbol, "JUMPI"])
        o.extend(
            _compile_to_assembly(code.args[1], withargs, existing_labels,
                                 break_dest, height))
        o.extend([end_symbol, "JUMPDEST"])
        return o
    # If statements (3 arguments, ie. if x: y, else: z)
    elif code.value == "if" and len(code.args) == 3:
        o = []
        o.extend(
            _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height))
        mid_symbol = mksymbol("else")
        end_symbol = mksymbol("join")
        o.extend(["ISZERO", mid_symbol, "JUMPI"])
        o.extend(
            _compile_to_assembly(code.args[1], withargs, existing_labels,
                                 break_dest, height))
        o.extend([end_symbol, "JUMP", mid_symbol, "JUMPDEST"])
        o.extend(
            _compile_to_assembly(code.args[2], withargs, existing_labels,
                                 break_dest, height))
        o.extend([end_symbol, "JUMPDEST"])
        return o

    # repeat(counter_location, start, rounds, rounds_bound, body)
    # basically a do-while loop:
    #
    # assert(rounds <= rounds_bound)
    # if (rounds > 0) {
    #   do {
    #     body;
    #   } while (++i != start + rounds)
    # }
    elif code.value == "repeat":
        o = []
        if len(code.args) != 5:
            raise CompilerPanic("bad number of repeat args")  # pragma: notest

        i_name = code.args[0]
        start = code.args[1]
        rounds = code.args[2]
        rounds_bound = code.args[3]
        body = code.args[4]

        entry_dest, continue_dest, exit_dest = (
            mksymbol("loop_start"),
            mksymbol("loop_continue"),
            mksymbol("loop_exit"),
        )

        # stack: []
        o.extend(
            _compile_to_assembly(
                start,
                withargs,
                existing_labels,
                break_dest,
                height,
            ))

        o.extend(
            _compile_to_assembly(rounds, withargs, existing_labels, break_dest,
                                 height + 1))

        # stack: i

        # assert rounds <= round_bound
        # NOTE(review): when rounds == rounds_bound the zero-rounds check
        # below is skipped as well; this assumes such a bound is nonzero
        # (a zero-round loop would otherwise run the body) -- confirm.
        if rounds != rounds_bound:
            # stack: i, rounds
            o.extend(
                _compile_to_assembly(rounds_bound, withargs, existing_labels,
                                     break_dest, height + 2))
            # stack: i, rounds, rounds_bound
            # assert rounds <= rounds_bound
            # TODO this runtime assertion should never fail for
            # internally generated repeats.
            # maybe drop it or jump to 0xFE
            o.extend(["DUP2", "GT", "_sym_revert0", "JUMPI"])

            # stack: i, rounds
            # if (0 == rounds) { goto end_dest; }
            o.extend(["DUP1", "ISZERO", exit_dest, "JUMPI"])

        # stack: start, rounds
        if start.value != 0:
            o.extend(["DUP2", "ADD"])

        # stack: i, exit_i
        o.extend(["SWAP1"])

        if i_name.value in withargs:
            raise CompilerPanic(f"shadowed loop variable {i_name}")
        withargs[i_name.value] = height + 1

        # stack: exit_i, i
        o.extend([entry_dest, "JUMPDEST"])
        o.extend(
            _compile_to_assembly(
                body,
                withargs,
                existing_labels,
                (exit_dest, continue_dest, height + 2),
                height + 2,
            ))

        del withargs[i_name.value]

        # clean up any stack items left by body
        o.extend(["POP"] * body.valency)

        # stack: exit_i, i
        # increment i:
        o.extend([continue_dest, "JUMPDEST", "PUSH1", 1, "ADD"])

        # stack: exit_i, i+1 (new_i)
        # if (exit_i != new_i) { goto entry_dest }
        o.extend(["DUP2", "DUP2", "XOR", entry_dest, "JUMPI"])
        o.extend([exit_dest, "JUMPDEST", "POP", "POP"])

        return o

    # Continue to the next iteration of the for loop
    elif code.value == "continue":
        if not break_dest:
            raise CompilerPanic("Invalid break")
        # NOTE(review): unlike `break`, no stack cleanup is emitted here; this
        # assumes `continue` is never nested inside a `with` in the loop body.
        dest, continue_dest, break_height = break_dest
        return [continue_dest, "JUMP"]
    # Break from inside a for loop
    elif code.value == "break":
        if not break_dest:
            raise CompilerPanic("Invalid break")
        dest, continue_dest, break_height = break_dest

        n_local_vars = height - break_height
        # clean up any stack items declared in the loop body
        cleanup_local_vars = ["POP"] * n_local_vars
        return cleanup_local_vars + [dest, "JUMP"]
    # Break from inside one or more for loops prior to a return statement inside the loop
    elif code.value == "cleanup_repeat":
        if not break_dest:
            raise CompilerPanic("Invalid break")
        _, _, break_height = break_dest
        # clean up local vars and internal loop vars
        return ["POP"] * break_height
    # With statements
    elif code.value == "with":
        o = []
        o.extend(
            _compile_to_assembly(code.args[1], withargs, existing_labels,
                                 break_dest, height))
        old = withargs.get(code.args[0].value, None)
        withargs[code.args[0].value] = height
        o.extend(
            _compile_to_assembly(
                code.args[2],
                withargs,
                existing_labels,
                break_dest,
                height + 1,
            ))
        # drop the bound variable, keeping the body's result (if any) on top
        if code.args[2].valency:
            o.extend(["SWAP1", "POP"])
        else:
            o.extend(["POP"])
        if old is not None:
            withargs[code.args[0].value] = old
        else:
            del withargs[code.args[0].value]
        return o
    # LLL statement (used to contain code inside code)
    elif code.value == "lll":
        o = []
        begincode = mksymbol("lll_begin")
        endcode = mksymbol("lll_end")
        o.extend([endcode, "JUMP", begincode, "BLANK"])

        lll = _compile_to_assembly(code.args[1], {}, existing_labels, None, 0)

        # `append(...)` call here is intentional.
        # each sublist is essentially its own program with its
        # own symbols.
        # in the later step when the "lll" block compiled to EVM,
        # compile_to_evm has logic to resolve symbols in "lll" to
        # position from start of runtime-code (instead of position
        # from start of bytecode).
        o.append(lll)

        # stack after this: code_length, begincode
        o.extend([endcode, "JUMPDEST", begincode, endcode, "SUB", begincode])
        # the destination is computed with two extra items on the stack, so
        # it must be compiled at height + 2 for DUP/SWAP offsets to be right
        o.extend(
            _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height + 2))

        # COPY the code to memory for deploy
        o.extend(["CODECOPY", begincode, endcode, "SUB"])
        return o
    # Seq (used to piece together multiple statements)
    elif code.value == "seq":
        o = []
        for arg in code.args:
            o.extend(
                _compile_to_assembly(arg, withargs, existing_labels,
                                     break_dest, height))
            # discard values of all but the last statement so that every
            # statement starts at the same stack height
            if arg.valency == 1 and arg != code.args[-1]:
                o.append("POP")
        return o
    # Assert that should never fail (if false, INVALID opcode)
    elif code.value == "assert_unreachable":
        o = _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height)
        end_symbol = mksymbol("reachable")
        o.extend([end_symbol, "JUMPI", "INVALID", end_symbol, "JUMPDEST"])
        return o
    # Assert (if false, exit)
    elif code.value == "assert":
        o = _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height)
        o.extend(["ISZERO"])
        o.extend(_assert_false())
        return o
    # Comparison clamps (lt/le/gt/ge; `u*` unsigned, others signed): assert
    # that `args[0] <op> args[1]` holds, leaving args[0] on the stack
    elif code.value in CLAMP_OP_NAMES:
        if isinstance(code.args[0].value, int) and isinstance(
                code.args[1].value, int):
            # Checks for clamp errors at compile time as opposed to run time
            # TODO move these to optimizer.py
            # Compare the literals the way the EVM will at runtime: as
            # unsigned words for the u* clamps (LT/GT) and as two's-complement
            # signed words for the others (SLT/SGT). Comparing raw Python ints
            # would, e.g., fold `uclampgt 5 -1` as safe even though -1 is
            # 2**256 - 1 at runtime, and would reject a valid `clamplt -5 3`.
            is_unsigned = code.value in ("uclamplt", "uclample", "uclampgt",
                                         "uclampge")

            def _as_evm_word(val):
                val %= 2**256
                if not is_unsigned and val >= 2**255:
                    val -= 2**256
                return val

            args_0_val = _as_evm_word(code.args[0].value)
            args_1_val = _as_evm_word(code.args[1].value)
            is_free_of_clamp_errors = any((
                code.value in ("uclamplt", "clamplt")
                and args_0_val < args_1_val,
                code.value in ("uclample", "clample")
                and args_0_val <= args_1_val,
                code.value in ("uclampgt", "clampgt")
                and args_0_val > args_1_val,
                code.value in ("uclampge", "clampge")
                and args_0_val >= args_1_val,
            ))
            if is_free_of_clamp_errors:
                return _compile_to_assembly(
                    code.args[0],
                    withargs,
                    existing_labels,
                    break_dest,
                    height,
                )
            else:
                raise Exception(
                    f"Invalid {code.value} with values {code.args[0]} and {code.args[1]}"
                )
        o = _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height)
        o.extend(
            _compile_to_assembly(
                code.args[1],
                withargs,
                existing_labels,
                break_dest,
                height + 1,
            ))
        o.extend(["DUP2"])
        # Stack (bottom -> top): num, bound, num
        # each comparison below yields 1 exactly when the clamp is violated
        if code.value == "uclamplt":
            o.extend(["LT", "ISZERO"])
        elif code.value == "clamplt":
            o.extend(["SLT", "ISZERO"])
        elif code.value == "uclample":
            o.extend(["GT"])
        elif code.value == "clample":
            o.extend(["SGT"])
        elif code.value == "uclampgt":
            o.extend(["GT", "ISZERO"])
        elif code.value == "clampgt":
            o.extend(["SGT", "ISZERO"])
        elif code.value == "uclampge":
            o.extend(["LT"])
        elif code.value == "clampge":
            o.extend(["SLT"])
        o.extend(_assert_false())
        return o
    # Range clamp (`clamp` signed, `uclamp` unsigned): assert
    # args[0] <= args[1] <= args[2], leaving args[1] on the stack
    elif code.value in ("clamp", "uclamp"):
        comp1 = "SGT" if code.value == "clamp" else "GT"
        comp2 = "SLT" if code.value == "clamp" else "LT"
        o = _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height)
        o.extend(
            _compile_to_assembly(
                code.args[1],
                withargs,
                existing_labels,
                break_dest,
                height + 1,
            ))
        o.extend(["DUP1"])
        o.extend(
            _compile_to_assembly(
                code.args[2],
                withargs,
                existing_labels,
                break_dest,
                height + 3,
            ))
        # stack: lo, x, x, hi -> check x > hi
        o.extend(["SWAP1", comp1])
        o.extend(_assert_false())
        # stack: lo, x -> check x < lo
        o.extend(["DUP1", "SWAP2", "SWAP1", comp2])
        o.extend(_assert_false())
        return o
    # Checks that a value is nonzero
    elif code.value == "clamp_nonzero":
        o = _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height)
        o.extend(["DUP1", "ISZERO"])
        o.extend(_assert_false())
        return o
    # SHA3 a single value
    elif code.value == "sha3_32":
        o = _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height)
        o.extend([
            "PUSH1",
            MemoryPositions.FREE_VAR_SPACE,
            "MSTORE",
            "PUSH1",
            32,
            "PUSH1",
            MemoryPositions.FREE_VAR_SPACE,
            "SHA3",
        ])
        return o
    # SHA3 a 64 byte value
    elif code.value == "sha3_64":
        o = _compile_to_assembly(code.args[0], withargs, existing_labels,
                                 break_dest, height)
        # args[0] is already on the stack, so args[1] runs one slot higher
        o.extend(
            _compile_to_assembly(code.args[1], withargs, existing_labels,
                                 break_dest, height + 1))
        o.extend([
            "PUSH1",
            MemoryPositions.FREE_VAR_SPACE2,
            "MSTORE",
            "PUSH1",
            MemoryPositions.FREE_VAR_SPACE,
            "MSTORE",
            "PUSH1",
            64,
            "PUSH1",
            MemoryPositions.FREE_VAR_SPACE,
            "SHA3",
        ])
        return o
    # unsigned <= operator
    elif code.value == "le":
        return _compile_to_assembly(
            LLLnode.from_list(["iszero", ["gt", code.args[0], code.args[1]]]),
            withargs,
            existing_labels,
            break_dest,
            height,
        )
    # unsigned >= operator
    elif code.value == "ge":
        return _compile_to_assembly(
            LLLnode.from_list(["iszero", ["lt", code.args[0], code.args[1]]]),
            withargs,
            existing_labels,
            break_dest,
            height,
        )
    # signed <= operator
    elif code.value == "sle":
        return _compile_to_assembly(
            LLLnode.from_list(["iszero", ["sgt", code.args[0], code.args[1]]]),
            withargs,
            existing_labels,
            break_dest,
            height,
        )
    # signed >= operator
    elif code.value == "sge":
        return _compile_to_assembly(
            LLLnode.from_list(["iszero", ["slt", code.args[0], code.args[1]]]),
            withargs,
            existing_labels,
            break_dest,
            height,
        )
    # != operator
    elif code.value == "ne":
        return _compile_to_assembly(
            LLLnode.from_list(["iszero", ["eq", code.args[0], code.args[1]]]),
            withargs,
            existing_labels,
            break_dest,
            height,
        )
    # e.g. 95 -> 96, 96 -> 96, 97 -> 128
    elif code.value == "ceil32":
        return _compile_to_assembly(
            LLLnode.from_list([
                "with",
                "_val",
                code.args[0],
                # in mod32 arithmetic, the solution to x + y == 32 is
                # y = bitwise_not(x) & 31
                ["add", "_val", ["and", ["not", ["sub", "_val", 1]], 31]],
            ]),
            withargs,
            existing_labels,
            break_dest,
            height,
        )
    # jump to a symbol, and push variable arguments onto stack
    elif code.value == "goto":
        o = []
        for i, c in enumerate(reversed(code.args[1:])):
            o.extend(
                _compile_to_assembly(c, withargs, existing_labels, break_dest,
                                     height + i))
        o.extend(["_sym_" + str(code.args[0]), "JUMP"])
        return o
    elif isinstance(code.value, str) and is_symbol(code.value):
        return [code.value]
    # set a symbol as a location.
    elif code.value == "label":
        label_name = str(code.args[0])

        if label_name in existing_labels:
            raise Exception(f"Label with name {label_name} already exists!")
        else:
            existing_labels.add(label_name)

        return ["_sym_" + label_name, "JUMPDEST"]
    # inject debug opcode.
    elif code.value == "debugger":
        return mkdebug(pc_debugger=False, pos=code.pos)
    # inject debug opcode.
    elif code.value == "pc_debugger":
        return mkdebug(pc_debugger=True, pos=code.pos)
    else:
        raise Exception("Weird code element: " + repr(code))
Example #5
0
def test_lll_compile_fail(bad_lll, get_contract_from_lll, assert_compile_failed):
    """Every malformed LLL program in ``bad_lll`` must fail to compile."""

    def compile_bad_program():
        program = LLLnode.from_list(bad_lll)
        return get_contract_from_lll(program)

    assert_compile_failed(compile_bad_program, Exception)
Example #6
0
def lll_tuple_from_args(args):
    """Pack LLL nodes into a single ``multi`` node typed as a tuple.

    The tuple type is built from the members' ``typ`` attributes, in order.
    """
    member_types = [arg.typ for arg in args]
    return LLLnode.from_list(["multi", *args], typ=TupleType(member_types))
Example #7
0
def _get_element_ptr_array(parent, key, pos, array_bounds_check):
    """Return an LLL pointer to element ``key`` of the array-like ``parent``.

    Args:
        parent: LLL node whose type is ArrayLike (static or dynamic array).
        key: LLL node for the index. It must have an integer type.
        pos: source position attached to the returned node.
        array_bounds_check: when true, wrap the index in a runtime
            ``uclamplt`` against the array's length (dynamic arrays) or its
            static count.

    Returns:
        An LLL node typed as the array's subtype, located in
        ``parent.location``. Literal ``~empty`` and ``multi`` parents are
        special-cased: the result is an ``~empty`` node or the key-th child.

    Raises:
        TypeCheckFailure: ``key`` is not an integer, or a bounds-checked
            index into a literal empty array is requested.
        CompilerPanic: ABI-encoded storage parents, or a parent location
            this helper cannot index.
    """
    assert isinstance(parent.typ, ArrayLike)

    if not is_integer_type(key.typ):
        raise TypeCheckFailure(f"{key.typ} used as array index")

    subtype = parent.typ.subtype

    if parent.value == "~empty":
        if array_bounds_check:
            # this case was previously missing a bounds check. codegen
            # is a bit complicated when bounds check is required, so
            # block it. there is no reason to index into a literal empty
            # array anyways!
            raise TypeCheckFailure("indexing into zero array not allowed")
        return LLLnode.from_list("~empty", subtype)

    if parent.value == "multi":
        # literal array: the element is the key-th child node, so the key
        # must be a compile-time integer
        assert isinstance(key.value, int)
        return parent.args[key.value]

    ix = unwrap_location(key)

    if array_bounds_check:
        # clamplt works, even for signed ints. since two's-complement
        # is used, if the index is negative, (unsigned) LT will interpret
        # it as a very large number, larger than any practical value for
        # an array index, and the clamp will throw an error.
        clamp_op = "uclamplt"
        is_darray = isinstance(parent.typ, DArrayType)
        bound = get_dyn_array_count(parent) if is_darray else parent.typ.count
        # NOTE: there are optimization rules for this when ix or bound is literal
        ix = LLLnode.from_list([clamp_op, ix, bound], typ=ix.typ)

    if parent.encoding in (Encoding.ABI, Encoding.JSON_ABI):
        if parent.location == "storage":
            raise CompilerPanic("storage variables should not be abi encoded"
                                )  # pragma: notest

        member_abi_t = subtype.abi_type

        ofst = _mul(ix, member_abi_t.embedded_static_size())

        return _getelemptr_abi_helper(parent, subtype, ofst, pos)

    # storage offsets are counted in words; the other locations use bytes
    if parent.location == "storage":
        element_size = subtype.storage_size_in_words
    elif parent.location in ("calldata", "memory", "data", "immutables"):
        element_size = subtype.memory_bytes_required
    else:
        # without this branch element_size would be unbound below and the
        # failure would surface as an UnboundLocalError
        raise CompilerPanic(
            f"Unsupported location: {parent.location}")  # pragma: notest

    ofst = _mul(ix, element_size)

    # skip the length word at the start of dynamic arrays
    if has_length_word(parent.typ):
        data_ptr = add_ofst(parent,
                            wordsize(parent.location) * DYNAMIC_ARRAY_OVERHEAD)
    else:
        data_ptr = parent

    return LLLnode.from_list(add_ofst(data_ptr, ofst),
                             typ=subtype,
                             location=parent.location,
                             pos=pos)
Example #8
0
def get_bytearray_length(arg):
    """Return an LLL node that loads the length word stored at ``arg``.

    The load opcode is chosen from ``arg.location``, and the result is typed
    uint256.
    """
    length_load = load_op(arg.location)
    return LLLnode.from_list([length_load, arg], typ=BaseType("uint256"))
Example #9
0
def copy_bytes(dst, src, length, length_bound, pos=None):
    """Generate LLL that copies ``length`` bytes from ``src`` to ``dst``.

    Args:
        dst: destination pointer (anything ``LLLnode.from_list`` accepts).
        src: source pointer (anything ``LLLnode.from_list`` accepts).
        length: number of bytes to copy. May be a runtime expression.
        length_bound: compile-time upper bound on ``length``.
        pos: source position attached to the loop in the general case.

    Strategies, tried in order:
      * ``length_bound <= 32``: a single word load + store. This copies
        one full word regardless of ``length``.
      * memory destination with a memory/calldata/data source: one bulk
        copy (identity precompile, CALLDATACOPY or ``dloadbytes``) with an
        explicit gas estimate.
      * otherwise: a word-by-word ``repeat`` loop. Storage is addressed by
        word index and the other locations by byte offset.

    ``src``, ``length`` and ``dst`` are wrapped with ``cache_when_complex``
    (presumably so each is evaluated only once).
    """
    annotation = f"copy_bytes from {src} to {dst}"

    src = LLLnode.from_list(src)
    dst = LLLnode.from_list(dst)
    length = LLLnode.from_list(length)

    with src.cache_when_complex("src") as (b1, src):
        with length.cache_when_complex("copy_word_count") as (b2, length):
            with dst.cache_when_complex("dst") as (b3, dst):

                def finish(node):
                    # close the three caching scopes around the final code
                    return b1.resolve(b2.resolve(b3.resolve(node)))

                # fast code for common case where num bytes is small
                # TODO expand this for more cases where num words is less than ~8
                if length_bound <= 32:
                    one_word = [
                        store_op(dst.location),
                        dst,
                        [load_op(src.location), src],
                    ]
                    return finish(
                        LLLnode.from_list(one_word, annotation=annotation))

                batch_sources = ("memory", "calldata", "data")
                if dst.location == "memory" and src.location in batch_sources:
                    # special cases: batch copy to memory
                    # TODO: iloadbytes
                    if src.location == "memory":
                        # memory -> memory through the identity precompile
                        bulk_op = [
                            "staticcall", "gas", 4, src, length, dst, length
                        ]
                        gas = _identity_gas_bound(length_bound)
                    elif src.location == "calldata":
                        bulk_op = ["calldatacopy", dst, src, length]
                        gas = _calldatacopy_gas_bound(length_bound)
                    else:
                        # note: dloadbytes compiles to CODECOPY
                        bulk_op = ["dloadbytes", dst, src, length]
                        gas = _codecopy_gas_bound(length_bound)
                    return finish(
                        LLLnode.from_list(bulk_op,
                                          annotation=annotation,
                                          add_gas_estimate=gas))

                # TODO: an immutables destination with a memory/data source
                # could use istorebytes-from-mem / CODECOPY respectively;
                # for now it takes the general loop below.

                # general case, copy word-for-word, e.g. memory -> storage:
                # for i in range(len, bound=MAX_LEN):
                #   sstore(_dst + i, mload(src + i * 32))
                # TODO should bump both pointers each iteration instead:
                #   _dst += 1; src += 32; sstore(_dst, mload(src))
                ix = LLLnode.from_list(_freshname("copy_bytes_ix"),
                                       typ="uint256")

                if src.location == "storage":
                    load_word = [load_op(src.location), ["add", src, ix]]
                elif src.location in ("memory", "calldata", "data",
                                      "immutables"):
                    load_word = [
                        load_op(src.location), ["add", src, _mul(32, ix)]
                    ]
                else:
                    raise CompilerPanic(
                        f"Unsupported location: {src.location}")  # pragma: notest

                if dst.location == "storage":
                    store_word = ["sstore", ["add", dst, ix], load_word]
                elif dst.location in ("memory", "immutables"):
                    store_word = [
                        store_op(dst.location),
                        ["add", dst, _mul(32, ix)],
                        load_word,
                    ]
                else:
                    raise CompilerPanic(
                        f"Unsupported location: {dst.location}")  # pragma: notest

                word_count = ["div", ["ceil32", length], 32]
                max_words = ceil32(length_bound) // 32
                loop = ["repeat", ix, 0, word_count, max_words, store_word]

                return finish(
                    LLLnode.from_list(loop, annotation=annotation, pos=pos))
Example #10
0
def _dynarray_make_setter(dst, src, pos=None):
    """Generate LLL that assigns the dynamic array ``src`` to ``dst``.

    Cases:
      * ``src`` is ``~empty``: only the length word at ``dst`` is set to 0.
      * ``src`` is a ``multi`` literal: write the length word, then set each
        element individually with ``make_setter``.
      * otherwise: either loop element-by-element with ``make_setter``
        (element type is ABI-dynamic or needs clamping), or bulk-copy the
        length word plus packed elements with ``copy_bytes``.

    Returns an LLLnode in every case. The ``multi`` case previously returned
    a raw Python list, unlike the other branches, so callers setting
    attributes such as ``annotation`` on the result would fail.
    """
    assert isinstance(src.typ, DArrayType)
    assert isinstance(dst.typ, DArrayType)

    if src.value == "~empty":
        return LLLnode.from_list([store_op(dst.location), dst, 0], pos=pos)

    if src.value == "multi":
        ret = ["seq"]
        # handle literals

        # write the length word
        store_length = [store_op(dst.location), dst, len(src.args)]
        ann = None
        if src.annotation is not None:
            ann = f"len({src.annotation})"
        store_length = LLLnode.from_list(store_length, annotation=ann)
        ret.append(store_length)

        n_items = len(src.args)
        for i in range(n_items):
            k = LLLnode.from_list(i, typ="uint256")
            dst_i = get_element_ptr(dst, k, pos=pos, array_bounds_check=False)
            src_i = get_element_ptr(src, k, pos=pos, array_bounds_check=False)
            ret.append(make_setter(dst_i, src_i, pos))

        # return an LLLnode, consistent with every other branch
        return LLLnode.from_list(ret, pos=pos)

    with src.cache_when_complex("darray_src") as (b1, src):

        # for ABI-encoded dynamic data, we must loop to unpack, since
        # the layout does not match our memory layout
        should_loop = (src.encoding in (Encoding.ABI, Encoding.JSON_ABI)
                       and src.typ.subtype.abi_type.is_dynamic())

        # if the subtype is dynamic, there might be a lot of
        # unused space inside of each element. for instance
        # DynArray[DynArray[uint256, 100], 5] where all the child
        # arrays are empty - for this case, we recursively call
        # into make_setter instead of straight bytes copy
        # TODO we can make this heuristic more precise, e.g.
        # loop when subtype.is_dynamic AND location == storage
        # OR array_size <= /bound where loop is cheaper than memcpy/
        should_loop |= src.typ.subtype.abi_type.is_dynamic()
        should_loop |= _needs_clamp(src.typ.subtype, src.encoding)

        if should_loop:
            uint = BaseType("uint256")

            # note: name clobbering for the ix is OK because
            # we never reach outside our level of nesting
            i = LLLnode.from_list(_freshname("copy_darray_ix"), typ=uint)

            loop_body = make_setter(
                get_element_ptr(dst, i, array_bounds_check=False, pos=pos),
                get_element_ptr(src, i, array_bounds_check=False, pos=pos),
                pos=pos,
            )
            loop_body.annotation = f"{dst}[i] = {src}[i]"

            with get_dyn_array_count(src).cache_when_complex(
                    "darray_count") as (b2, len_):
                store_len = [store_op(dst.location), dst, len_]
                loop = ["repeat", i, 0, len_, src.typ.count, loop_body]

                return b1.resolve(b2.resolve(["seq", store_len, loop]))

        element_size = src.typ.subtype.memory_bytes_required
        # 32 bytes + number of elements * size of element in bytes
        n_bytes = ["add", _mul(get_dyn_array_count(src), element_size), 32]
        max_bytes = src.typ.memory_bytes_required

        return b1.resolve(copy_bytes(dst, src, n_bytes, max_bytes, pos=pos))