Ejemplo n.º 1
0
def to_bytes32(expr, args, kwargs, context):
    """Convert the first positional argument into a ``bytes32`` LLL node.

    For a ``Bytes`` input, the first data word is loaded and every byte past
    the runtime length is zeroed. Any other input type is passed through
    unchanged and retyped as ``bytes32``.

    :param expr: source expression, used for the node's source position.
    :param args: call arguments; only ``args[0]`` is converted.
    :param kwargs: unused here.
    :param context: unused here.
    :returns: an LLLnode typed ``bytes32``.
    :raises TypeMismatch: if the ``Bytes`` input's maximum length exceeds 32.
    """
    src = args[0]
    src_type, max_len = get_type(src)

    if src_type != "Bytes":
        # literal: no conversion work needed
        return LLLnode.from_list(src, typ="bytes32", pos=getpos(expr))

    if max_len > 32:
        raise TypeMismatch(
            f"Unable to convert bytes[{max_len}] to bytes32, max length is too "
            "large.")

    with src.cache_when_complex("bytes") as (outer, src):
        load = load_op(src.location)
        # skip the length word(s) that precede the byte data
        data_offset = wordsize(src.location) * DYNAMIC_ARRAY_OVERHEAD
        first_word = [load, ["add", src, data_offset]]

        # the final word of a bytearray can carry dirty bytes beyond its
        # length; shifting right then left by the unused bit count clears them
        byte_len = get_bytearray_length(src)
        pad_bits = LLLnode.from_list(["mul", ["sub", 32, byte_len], 8])
        with pad_bits.cache_when_complex("bits") as (inner, pad_bits):
            cleaned = shl(pad_bits, shr(pad_bits, first_word))
            ret = outer.resolve(inner.resolve(cleaned))

    return LLLnode.from_list(ret, typ="bytes32", pos=getpos(expr))
Ejemplo n.º 2
0
def to_bytes_m(expr, arg, out_typ):
    """Convert the IR node `arg` to the fixed-size bytes type `out_typ`.

    The result occupies the high-order bytes of the 32-byte word.

    * Byte arrays: the first data word is loaded and the bytes past the
      runtime length are zeroed.
    * bytesM: a clamp is applied when narrowing to a smaller M.
    * Integers/addresses and decimals: a source wider than M bytes is
      reported via ``_FAIL``. The value is then shifted left into the
      high-order bytes.
    * Anything else (bool, per the original branch comment): the value is
      shifted left into the high-order bytes with no width check.

    :param expr: source expression, used for error reporting.
    :param arg: IR node to convert.
    :param out_typ: target bytesM type.
    :returns: IR node typed as `out_typ`.
    """
    out_info = out_typ._bytes_info
    _check_bytes(expr, arg, out_typ, max_bytes_allowed=out_info.m)

    src_typ = arg.typ

    if isinstance(src_typ, ByteArrayType):
        first_word = LOAD(bytes_data_ptr(arg))
        # the last word of a bytearray may hold dirty bytes past its length;
        # shifting them out and back in zeroes them
        byte_len = get_bytearray_length(arg)
        pad_bits = IRnode.from_list(["mul", ["sub", 32, byte_len], 8])
        with pad_bits.cache_when_complex("bits") as (scope, pad_bits):
            arg = scope.resolve(shl(pad_bits, shr(pad_bits, first_word)))

    elif is_bytes_m_type(src_typ):
        # only a narrowing bytesM -> bytesM conversion needs a clamp
        if src_typ._bytes_info.m > out_info.m:
            arg = bytes_clamp(arg, out_info.m)

    else:
        if is_integer_type(src_typ) or is_base_type(src_typ, "address"):
            src_bits = src_typ._int_info.bits
        elif is_decimal_type(src_typ):
            src_bits = src_typ._decimal_info.bits
        else:
            # bool: no width check
            src_bits = None

        # no runtime clamp is emitted for too-wide sources; they are
        # rejected here instead
        if src_bits is not None and out_info.m_bits < src_bits:
            _FAIL(src_typ, out_typ, expr)

        # move the value into the high-order bytes; negative numbers are not
        # treated as out of bounds and keep their sign bit
        arg = shl(256 - out_info.m_bits, arg)

    return IRnode.from_list(arg, typ=out_typ)
Ejemplo n.º 3
0
def _bytes_to_num(arg, out_typ, signed):
    """Reinterpret a bytestring or bytesM value as a number (no range clamp).

    Bytestrings and bytesM are right-padded with zeroes, while integers are
    left-padded. The value is therefore shifted right by the number of
    unused bytes, converted to bits. For example, "abcd000000000000"
    becomes bitcast(000000000000abcd, output_type).

    :param arg: IR node whose ``typ`` is ``ByteArrayLike`` or a bytesM type.
    :param out_typ: target type; only used in the node annotation here.
    :param signed: use an arithmetic shift (``sar``) instead of a logical
        shift (``shr``).
    :returns: an IR node carrying the annotation, with no ``typ`` set.
    :raises CompilerPanic: if ``arg.typ`` is neither a bytestring nor bytesM.
    """
    if isinstance(arg.typ, ByteArrayLike):
        # NOTE(review): `arg` is referenced twice (length and data pointer);
        # presumably callers cache complex args beforehand -- confirm.
        _len = get_bytearray_length(arg)
        arg = LOAD(bytes_data_ptr(arg))
        num_zero_bits = ["mul", 8, ["sub", 32, _len]]
    elif is_bytes_m_type(arg.typ):
        info = arg.typ._bytes_info
        num_zero_bits = 8 * (32 - info.m)
    else:
        raise CompilerPanic("unreachable")  # pragma: notest

    if signed:
        ret = sar(num_zero_bits, arg)
    else:
        ret = shr(num_zero_bits, arg)

    # a plain string, like the other intrinsic annotations; the previous
    # trailing comma accidentally produced a 1-tuple
    annotation = f"__intrinsic__byte_array_to_num({out_typ})"
    return IRnode.from_list(ret, annotation=annotation)
Ejemplo n.º 4
0
def byte_array_to_num(arg, out_type):
    """
    Build LLL converting a bytestring `arg` into a number of type `out_type`.

    The length and the first data word are read from wherever `arg` lives:
    storage slot offset 1 for storage, byte offset 32 otherwise. The data
    word is shifted right past its unused bytes, and the result is passed
    through ``clamp_basetype``.

    Assumes the bytestring holds at most 32 bytes (callers presumably enforce
    this). `out_type` is presumably a base type name accepted by ``BaseType``
    -- confirm against callers.
    """
    is_complex = arg.is_complex_lll
    if is_complex:
        # refer to the bytestring through a `with`-bound name so the complex
        # expression is evaluated only once
        bs_start = LLLnode.from_list(
            "bs_start", typ=arg.typ, location=arg.location,
            encoding=arg.encoding)
    else:
        bs_start = arg

    if arg.location == "storage":
        length = get_bytearray_length(bs_start)
        word = LLLnode.from_list(["sload", add_ofst(bs_start, 1)],
                                 typ=BaseType("int256"))
    else:
        load = load_op(arg.location)
        length = LLLnode.from_list([load, bs_start], typ=BaseType("int256"))
        word = LLLnode.from_list([load, add_ofst(bs_start, 32)],
                                 typ=BaseType("int256"))

    # bytestrings are right-padded with zeroes and ints are left-padded, so
    # shift right by the number of unused bytes (converted to bits), e.g.
    # "abcd000000000000" -> bitcast(000000000000abcd, output_type)
    pad_bits = ["mul", 8, ["sub", 32, "len_"]]
    shifted = LLLnode.from_list(shr(pad_bits, "val"), typ=out_type)
    clamped = clamp_basetype(shifted)

    # TODO use cache_when_complex for these `with` values
    body = ["with", "val", word, ["with", "len_", length, clamped]]
    if is_complex:
        body = ["with", "bs_start", arg, body]

    return LLLnode.from_list(
        body,
        typ=BaseType(out_type),
        annotation=f"__intrinsic__byte_array_to_num({out_type})",
    )
Ejemplo n.º 5
0
def parse_regular_functions(regular_functions, sigs, external_interfaces,
                            global_ctx, default_function, init_function):
    """Generate runtime IR for the regular functions and the fallback.

    Layout of the returned runtime IR:

    1. The selector section: payable external functions first, then
       nonpayable external functions.
    2. The fallback.
    3. Internal functions.

    Internal functions come last so execution never has to jump over them.

    :param regular_functions: function AST nodes other than the default
        function; internal functions are included and separated out here.
    :param sigs: mapping of function name -> signature used for calls via
        ``self``; updated in place with each compiled function's signature.
    :param external_interfaces: mapping merged alongside ``{"self": sigs}``
        when compiling each function.
    :param global_ctx: global context; its ``_defs`` are scanned for the
        mutability of every external function.
    :param default_function: the default (fallback) function AST node, or
        None when the contract defines none.
    :param init_function: unused here.
    :returns: a ``["seq", ...]`` IR list for the runtime code.
    """
    # check for payable/nonpayable external functions to optimize
    # nonpayable assertions
    func_types = [i._metadata["type"] for i in global_ctx._defs]
    mutabilities = [
        i.mutability for i in func_types
        if i.visibility == FunctionVisibility.EXTERNAL
    ]
    has_payable = any(i == StateMutability.PAYABLE for i in mutabilities)
    has_nonpayable = any(i != StateMutability.PAYABLE for i in mutabilities)

    is_default_payable = (default_function is not None
                          and default_function._metadata["type"].mutability
                          == StateMutability.PAYABLE)

    # TODO streamline the nonpayable check logic

    # when a contract has a payable default function and at least one
    # nonpayable external function, we must perform the nonpayable check on
    # every function
    check_per_function = is_default_payable and has_nonpayable

    # generate IR for regular functions
    payable_funcs = []
    nonpayable_funcs = []
    internal_funcs = []
    # running gas surcharge: each external function adds 30, and every
    # function compiled afterwards (internal ones included) inherits the
    # accumulated total.
    # NOTE(review): presumably this models the linear selector search;
    # the constant's origin is undocumented ("CMC 20210910 why?").
    add_gas = 0

    for func_node in regular_functions:
        func_type = func_node._metadata["type"]
        func_ir, frame_start, frame_size = generate_ir_for_function(
            func_node, {"self": sigs, **external_interfaces}, global_ctx,
            check_per_function)

        if func_type.visibility == FunctionVisibility.INTERNAL:
            internal_funcs.append(func_ir)

        elif func_type.mutability == StateMutability.PAYABLE:
            add_gas += 30
            payable_funcs.append(func_ir)

        else:
            add_gas += 30
            nonpayable_funcs.append(func_ir)

        func_ir.total_gas += add_gas

        # update sigs with metadata gathered from compiling the function so
        # that we can handle calls to self
        # TODO we only need to do this for internal functions; external
        # functions cannot be called via `self`
        sig = FunctionSignature.from_definition(func_node, external_interfaces,
                                                global_ctx._structs)
        sig.gas = func_ir.total_gas
        sig.frame_start = frame_start
        sig.frame_size = frame_size
        sigs[sig.name] = sig

    # generate IR for fallback function
    if default_function:
        # Always request the nonpayable check for the fallback. Short
        # calldata (`calldatasize < 4`) jumps straight to the `fallback`
        # label and never executes the batched `assert iszero callvalue`
        # in the selector section, so relying on that batched check let a
        # nonpayable default function accept ether. Payable functions
        # (including a payable default function) already receive this flag
        # whenever `check_per_function` is set, so it is presumably a no-op
        # for them.
        fallback_ir, _frame_start, _frame_size = generate_ir_for_function(
            default_function,
            {"self": sigs, **external_interfaces},
            global_ctx,
            True,
        )
    else:
        fallback_ir = IRnode.from_list(["revert", 0, 0],
                                       typ=None,
                                       annotation="Default function")

    if check_per_function:
        external_seq = ["seq"] + payable_funcs + nonpayable_funcs

    else:
        # payable functions are placed prior to nonpayable functions
        # and separated by a nonpayable assertion
        external_seq = ["seq"]
        if has_payable:
            external_seq += payable_funcs
        if has_nonpayable:
            external_seq.append(["assert", ["iszero", "callvalue"]])
            external_seq += nonpayable_funcs

    # ensure the external jumptable section gets closed out
    # (for basic block hygiene and also for zksync interpreter)
    # NOTE: this jump gets optimized out in assembly since the
    # fallback label is the immediate next instruction.
    close_selector_section = ["goto", "fallback"]

    # bytecode is organized by: external functions, fallback fn, internal
    # functions; this way we save gas and reduce bytecode by not jumping
    # over internal functions
    runtime = [
        "seq",
        # check that calldatasize is at least 4, otherwise
        # calldataload will load zeros (cf. yellow paper).
        ["if", ["lt", "calldatasize", 4], ["goto", "fallback"]],
        [
            "with", "_calldata_method_id",
            shr(224, ["calldataload", 0]), external_seq
        ],
        close_selector_section,
        ["label", "fallback", ["var_list"], fallback_ir],
    ]
    runtime.extend(internal_funcs)

    return runtime
Ejemplo n.º 6
0
def _runtime_ir(runtime_functions, all_sigs, global_ctx):
    """Assemble the runtime IR for a contract.

    The runtime code is laid out as follows:

    1. The selector section: payable external functions, an optional
       batched ``callvalue`` assertion, then nonpayable external functions.
    2. The fallback.
    3. The internal functions.

    :param runtime_functions: function AST nodes to place in runtime code.
    :param all_sigs: signatures passed through to ``generate_ir_for_function``.
    :param global_ctx: global context passed through to
        ``generate_ir_for_function``.
    :returns: ``(runtime, internal_functions_map)``. ``runtime`` is a
        ``["seq", ...]`` IR list. ``internal_functions_map`` maps each
        internal function name to its generated IR.
    """
    # categorize the runtime functions because we will organize the runtime
    # code into the following sections:
    # payable functions, nonpayable functions, fallback function,
    # internal functions
    internal_functions = [f for f in runtime_functions if _is_internal(f)]

    external_functions = [f for f in runtime_functions if not _is_internal(f)]
    default_function = next(
        (f for f in external_functions if _is_default_func(f)), None)

    # functions exposed through the selector section
    regular_functions = [
        f for f in external_functions if not _is_default_func(f)
    ]
    payables = [f for f in regular_functions if _is_payable(f)]
    nonpayables = [f for f in regular_functions if not _is_payable(f)]

    # create a map of the IR functions since they might live in both
    # runtime and deploy code (if init function calls them)
    internal_functions_map: Dict[str, IRnode] = {}

    for func_ast in internal_functions:
        func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx,
                                           False)
        internal_functions_map[func_ast.name] = func_ir

    # for some reason, somebody may want to deploy a contract with no
    # external functions, or more likely, a "pure data" contract which
    # contains immutables
    if len(external_functions) == 0:
        # TODO: prune internal functions in this case?
        runtime = ["seq"] + list(internal_functions_map.values())
        return runtime, internal_functions_map

    # note: if the user does not provide one, the default fallback function
    # reverts anyway. so it does not hurt to batch the payable check.
    default_is_nonpayable = default_function is None or not _is_payable(
        default_function)

    # when a contract has a nonpayable default function,
    # we can do a single check for all nonpayable functions
    batch_payable_check = len(nonpayables) > 0 and default_is_nonpayable
    skip_nonpayable_check = batch_payable_check

    selector_section = ["seq"]

    for func_ast in payables:
        func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx,
                                           False)
        selector_section.append(func_ir)

    if batch_payable_check:
        selector_section.append(["assert", ["iszero", "callvalue"]])

    for func_ast in nonpayables:
        func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx,
                                           skip_nonpayable_check)
        selector_section.append(func_ir)

    if default_function:
        # Never skip the fallback's own nonpayable check. The
        # `calldatasize < 4` guard below jumps straight to the `fallback`
        # label, bypassing the batched `assert iszero callvalue` in the
        # selector section. A nonpayable fallback compiled with the check
        # skipped would therefore accept ether on short calldata.
        fallback_ir = generate_ir_for_function(default_function, all_sigs,
                                               global_ctx, False)
    else:
        fallback_ir = IRnode.from_list(["revert", 0, 0],
                                       annotation="Default function")

    # ensure the external jumptable section gets closed out
    # (for basic block hygiene and also for zksync interpreter)
    # NOTE: this jump gets optimized out in assembly since the
    # fallback label is the immediate next instruction.
    close_selector_section = ["goto", "fallback"]

    runtime = [
        "seq",
        # check that calldatasize is at least 4, otherwise
        # calldataload will load zeros (cf. yellow paper).
        ["if", ["lt", "calldatasize", 4], ["goto", "fallback"]],
        [
            "with", "_calldata_method_id",
            shr(224, ["calldataload", 0]), selector_section
        ],
        close_selector_section,
        ["label", "fallback", ["var_list"], fallback_ir],
    ]

    # TODO: prune unreachable functions?
    runtime.extend(internal_functions_map.values())

    return runtime, internal_functions_map