예제 #1
0
    def size_in_bytes(self):
        """Return the number of bytes reserved for this byte-array type.

        The layout is one 32-byte slot holding the actual length, followed by
        the data area.  The data area is sized for the maximum declared length
        (``self.length``) and rounded up with ``ceil32`` so that the total is
        aligned to a 32-byte boundary rather than ending mid-slot.
        """
        length_slot_bytes = 32
        padded_data_bytes = ceil32(self.length)
        return length_slot_bytes + padded_data_bytes
예제 #2
0
    def from_declaration(cls, code, custom_units=None):
        """Build an event signature from an event declaration node.

        The declaration is expected to carry the event name in
        ``code.target.id`` and, when the event has arguments, a dict of
        ``name: type`` pairs in ``code.annotation.args[0]``.  A type wrapped
        in ``indexed(...)`` marks the argument as a topic.

        :param code: declaration node of the event.
        :param custom_units: forwarded to name validation and type parsing.
        :return: ``cls(name, args, indexed_list, event_id, sig)``.
        :raises EventDeclarationException: invalid event name, an indexed
            ``bytes`` argument longer than 32, or more than 3 topics.
        :raises VariableDeclarationException: invalid, reserved or duplicate
            argument name.
        :raises InvalidTypeException: argument without a type.
        """
        event_name = code.target.id
        if not is_varname_valid(event_name, custom_units=custom_units):
            raise EventDeclarationException("Event name invalid: " + event_name)

        records = []
        indexed_flags = []
        topic_total = 1  # starts at 1; the limit check below allows up to 4
        offset = 0  # running position handed to each VariableRecord

        declared = code.annotation.args
        if declared:
            arg_dict = declared[0]
            for key_node, type_node in zip(arg_dict.keys, arg_dict.values):
                arg = key_node.id

                # An `indexed(<type>)` wrapper marks the argument as a topic;
                # unwrap it so the remaining checks see the inner type.
                is_indexed = (isinstance(type_node, ast.Call)
                              and type_node.func.id == 'indexed')
                if is_indexed:
                    type_node = type_node.args[0]
                    topic_total += 1
                indexed_flags.append(is_indexed)

                oversized_bytes = (
                    isinstance(type_node, ast.Subscript)
                    and getattr(type_node.value, 'id', None) == 'bytes'
                    and type_node.slice.value.n > 32
                )
                if oversized_bytes and is_indexed:
                    raise EventDeclarationException(
                        "Indexed arguments are limited to 32 bytes")
                if topic_total > 4:
                    raise EventDeclarationException(
                        "Maximum of 3 topics {} given".format(topic_total - 1),
                        arg)
                if not isinstance(arg, str):
                    raise VariableDeclarationException(
                        "Argument name invalid", arg)
                if not type_node:
                    raise InvalidTypeException("Argument must have type", arg)
                if not is_varname_valid(arg, custom_units):
                    raise VariableDeclarationException(
                        "Argument name invalid or reserved: " + arg, arg)
                if any(existing.name == arg for existing in records):
                    raise VariableDeclarationException(
                        "Duplicate function argument name: " + arg, arg)

                parsed = parse_type(type_node, None, custom_units=custom_units)
                records.append(VariableRecord(arg, offset, parsed, False))
                if isinstance(parsed, ByteArrayType):
                    step = ceil32(type_node.slice.value.n)
                else:
                    step = get_size_of_type(parsed) * 32
                offset += step

        canonical_types = ','.join(
            canonicalize_type(record.typ, flag)
            for record, flag in zip(records, indexed_flags)
        )
        sig = event_name + '(' + canonical_types + ')'
        event_id = bytes_to_int(sha3(bytes(sig, 'utf-8')))
        return cls(event_name, records, indexed_flags, event_id, sig)
예제 #3
0
    def from_declaration(cls, class_node, global_ctx):
        """Create an event signature from an event class definition.

        Every statement of the class body is read as one event argument: the
        assignment target gives the name and the annotation gives the type,
        optionally wrapped in ``indexed(...)``.  A body consisting of a single
        ``pass`` yields an event without arguments.

        :param class_node: event definition node (``name`` and ``body`` are read).
        :param global_ctx: provides ``_structs`` / ``_constants`` for name
            validation and ``parse_type`` for the annotations.
        :return: ``cls(name, args, indexed_list, event_id, sig)``.
        :raises TypeCheckFailure: when two arguments share a name.  Name
            validation is asked to report an invalid event name as
            ``EventDeclarationException``.
        """
        event_name = class_node.name
        check_valid_varname(
            event_name,
            global_ctx._structs,
            global_ctx._constants,
            pos=class_node,
            error_prefix="Event name invalid. ",
            exc=EventDeclarationException,
        )

        records = []
        indexed_flags = []
        offset = 0  # running position handed to each VariableRecord

        body = class_node.body
        only_pass = len(body) == 1 and isinstance(body[0], vy_ast.Pass)
        members = [] if only_pass else body

        for member in members:
            target = member.target
            arg_name = target.id
            type_node = member.annotation

            # `indexed(<type>)` flags the argument; continue with the inner type.
            is_indexed = (isinstance(type_node, vy_ast.Call)
                          and type_node.get("func.id") == "indexed")
            indexed_flags.append(is_indexed)
            if is_indexed:
                type_node = type_node.args[0]

            check_valid_varname(
                arg_name,
                global_ctx._structs,
                global_ctx._constants,
                pos=target,
                error_prefix="Event argument name invalid or reserved.",
            )
            if any(existing.name == arg_name for existing in records):
                raise TypeCheckFailure(
                    f"Duplicate function argument name: {arg_name}")

            # Can struct be logged?
            parsed = global_ctx.parse_type(type_node, None)
            records.append(VariableRecord(arg_name, offset, parsed, False))
            if isinstance(parsed, ByteArrayType):
                step = ceil32(type_node.slice.value.n)
            else:
                step = get_size_of_type(parsed) * 32
            offset += step

        canonical_types = ",".join(
            canonicalize_type(record.typ, flag)
            for record, flag in zip(records, indexed_flags)
        )
        sig = event_name + "(" + canonical_types + ")"
        event_id = bytes_to_int(keccak256(bytes(sig, "utf-8")))
        return cls(event_name, records, indexed_flags, event_id, sig)
예제 #4
0
파일: types.py 프로젝트: siromivel/vyper
def get_size_of_type(typ):
    """Return the size of ``typ`` counted in 32-byte slots.

    Base types take one slot; byte arrays take the slots needed for
    ``maxlen`` bytes plus two extra slots (presumably offset and length
    words -- confirm against the ABI layout); lists multiply the element
    size by the element count; tuple-like types sum their members.

    :raises Exception: for mapping types and for any unrecognised type.
    """
    if isinstance(typ, BaseType):
        return 1
    if isinstance(typ, ByteArrayLike):
        data_slots = ceil32(typ.maxlen) // 32
        return data_slots + 2
    if isinstance(typ, ListType):
        per_item = get_size_of_type(typ.subtype)
        return per_item * typ.count
    if isinstance(typ, MappingType):
        raise Exception("Maps are not supported for function arguments or outputs.")
    if isinstance(typ, TupleLike):
        return sum(get_size_of_type(member) for member in typ.tuple_members())
    raise Exception("Can not get size of type, Unexpected type: %r" % repr(typ))
예제 #5
0
def parse_global(stdout, global_vars, computation, line):
    """Look up a contract global in storage and print its value to ``stdout``.

    :param stdout: writable stream that receives the output and error text.
    :param global_vars: mapping of global name to a dict with at least the
        keys ``"type"`` (type string) and ``"position"`` (storage slot).
    :param computation: object exposing ``state.account_db.get_storage`` and
        ``msg.storage_address``; used to read storage words.
    :param line: command text; the piece between the first and second ``.``
        is taken as the global's name, optionally followed by ``[...]``
        subscripts for mappings.

    Base types and byte-like types (``bytes...`` / ``string...``) are read
    from the variable's own slot; ``map`` types are read from the slot
    derived from the subscript keys.  Anything else is reported as
    unreadable.
    """
    name = line.split(".")[1]
    var_name = name[: name.find("[")] if "[" in name else name

    if var_name not in global_vars:
        stdout.write('Global named "{}" not found.'.format(var_name) + "\n")
        return

    var_info = global_vars[var_name]
    global_type = var_info["type"]
    is_bytelike = global_type.startswith("bytes") or global_type.startswith("string")

    slot = None
    if global_type in base_types or is_bytelike:
        slot = var_info["position"]
    elif global_type.startswith("map") and valid_subscript(name, global_type):
        slot = get_hash(var_info["position"], get_keys(name), global_type)

    if slot is None:
        stdout.write('Can not read global of type "{}".\n'.format(global_type))
        return

    def read_word(at_slot):
        # Single storage word of the contract currently being executed.
        return computation.state.account_db.get_storage(
            address=computation.msg.storage_address, slot=at_slot,
        )

    if is_bytelike:
        base_slot_hash = big_endian_to_int(
            keccak(int_to_big_endian(slot).rjust(32, b"\0"))
        )
        # The first word at the hashed slot is the byte length; the data
        # words follow it.
        len_val = read_word(base_slot_hash)
        # int_to_big_endian yields the shortest encoding, so a word whose
        # leading bytes are zero would come back shorter than 32 bytes and
        # shift all following data.  Left-pad every word to full width.
        words = [
            int_to_big_endian(read_word(base_slot_hash + 1 + i)).rjust(32, b"\0")
            for i in range(ceil32(len_val) // 32)
        ]
        value = b"".join(words)[:len_val]
    else:
        value = read_word(slot)

    if global_type.startswith("map"):
        # Print using the mapping's value type: the text between the last
        # "," and the last ")" of the type string.
        global_type = global_type[
            global_type.rfind(",") + 1 : global_type.rfind(")")
        ].strip()
    print_var(stdout, value, global_type)
예제 #6
0
def get_size_of_type(typ):
    """Return how many 32-byte words a value of ``typ`` occupies.

    :raises InvalidType: for mapping types (not supported for function
        arguments or outputs) and for any type not handled here.
    """
    if isinstance(typ, BaseType):
        return 1

    if isinstance(typ, ByteArrayLike):
        # 1 word for offset (in static section), 1 word for length,
        # up to maxlen words for actual data.
        offset_and_length_words = 2
        data_words = ceil32(typ.maxlen) // 32
        return data_words + offset_and_length_words

    if isinstance(typ, ListType):
        words_per_item = get_size_of_type(typ.subtype)
        return words_per_item * typ.count

    if isinstance(typ, MappingType):
        raise InvalidType("Maps are not supported for function arguments or outputs.")

    if isinstance(typ, TupleLike):
        return sum(get_size_of_type(member) for member in typ.tuple_members())

    raise InvalidType(f"Can not get size of type, Unexpected type: {repr(typ)}")
예제 #7
0
def pack_logging_data(expected_data, args, context, pos):
    """
    Build the LLL that lays out the data (non-topic) arguments of a log call in memory.

    :param expected_data: declared event data arguments; only ``.typ`` is read here.
    :param args: AST nodes of the values supplied at the call site, paired
                 positionally with ``expected_data``.
    :param context: function context used to allocate variables/placeholders.
    :param pos: source position forwarded to errors and helper calls.
    :return: ``(holder, maxlen, dynamic_offset_counter, datamem_start)`` where
             ``holder`` is a ``['seq', ...]`` list of LLL statements, ``maxlen``
             the upper bound in bytes of the packed data,
             ``dynamic_offset_counter`` the placeholder tracking the next
             dynamic write offset (``None`` if no byte arrays are logged) and
             ``datamem_start`` the address where the packed data begins.
    :raises TypeMismatchException: a string literal is longer than its type allows.
    """
    # Checks to see if there's any data
    if not args:
        return ['seq'], 0, None, 0
    holder = ['seq']
    maxlen = len(args) * 32  # total size of all packed args (upper limit)

    # Unroll any function calls, to temp variables.
    # Maps argument index -> LLLnode of the temp variable holding its value.
    prealloacted = {}
    for idx, (arg, _expected_arg) in enumerate(zip(args, expected_data)):

        if isinstance(arg, (ast.Str, ast.Call)):
            expr = Expr(arg, context)
            source_lll = expr.lll_node
            typ = source_lll.typ

            if isinstance(arg, ast.Str):
                if len(arg.s) > typ.maxlen:
                    raise TypeMismatchException(
                        "Data input bytes are to big: %r %r" %
                        (len(arg.s), typ), pos)

            # Name the temp after the source location of the argument.
            tmp_variable = context.new_variable(
                '_log_pack_var_%i_%i' % (arg.lineno, arg.col_offset),
                source_lll.typ,
            )
            tmp_variable_node = LLLnode.from_list(
                tmp_variable,
                typ=source_lll.typ,
                pos=getpos(arg),
                location="memory",
                annotation='log_prealloacted %r' % source_lll.typ,
            )
            # Store len.
            # holder.append(['mstore', len_placeholder, ['mload', unwrap_location(source_lll)]])
            # Copy bytes.

            # Evaluate the expression once into the temp variable; later
            # packing steps read from the temp instead of the AST node.
            holder.append(
                make_setter(tmp_variable_node,
                            source_lll,
                            pos=getpos(arg),
                            location='memory'))
            prealloacted[idx] = tmp_variable_node

    requires_dynamic_offset = any(
        [isinstance(data.typ, ByteArrayLike) for data in expected_data])
    if requires_dynamic_offset:
        # Iterator used to zero pad memory.
        zero_pad_i = context.new_placeholder(BaseType('uint256'))
        dynamic_offset_counter = context.new_placeholder(BaseType(32))
        dynamic_placeholder = context.new_placeholder(BaseType(32))
    else:
        # NOTE: `dynamic_placeholder` is left unbound on this path; it is only
        # read below when `requires_dynamic_offset` is true.
        dynamic_offset_counter = None
        zero_pad_i = None

    # Create placeholder for static args. Note: order of new_*() is important.
    placeholder_map = {}
    for i, (_arg, data) in enumerate(zip(args, expected_data)):
        typ = data.typ
        if not isinstance(typ, ByteArrayLike):
            placeholder = context.new_placeholder(typ)
        else:
            # Byte arrays get a single word in the static section (it is
            # filled with the dynamic offset by pack_args_by_32).
            placeholder = context.new_placeholder(BaseType(32))
        placeholder_map[i] = placeholder

    # Populate static placeholders.
    for i, (arg, data) in enumerate(zip(args, expected_data)):
        typ = data.typ
        placeholder = placeholder_map[i]
        if not isinstance(typ, ByteArrayLike):
            # pack_args_by_32 may grow maxlen (e.g. for list arguments).
            holder, maxlen = pack_args_by_32(
                holder,
                maxlen,
                prealloacted.get(i, arg),
                typ,
                context,
                placeholder,
                zero_pad_i=zero_pad_i,
                pos=pos,
            )

    # Dynamic position starts right after the static args.
    if requires_dynamic_offset:
        holder.append(
            LLLnode.from_list(['mstore', dynamic_offset_counter, maxlen]))

    # Calculate maximum dynamic offset placeholders, used for gas estimation.
    for _arg, data in zip(args, expected_data):
        typ = data.typ
        if isinstance(typ, ByteArrayLike):
            # one length word plus the padded maximum data size
            maxlen += 32 + ceil32(typ.maxlen)

    if requires_dynamic_offset:
        datamem_start = dynamic_placeholder + 32
    else:
        datamem_start = placeholder_map[0]

    # Copy necessary data into allocated dynamic section.
    for i, (arg, data) in enumerate(zip(args, expected_data)):
        typ = data.typ
        if isinstance(typ, ByteArrayLike):
            # Return value is ignored here: holder is extended in place.
            pack_args_by_32(holder=holder,
                            maxlen=maxlen,
                            arg=prealloacted.get(i, arg),
                            typ=typ,
                            context=context,
                            placeholder=placeholder_map[i],
                            datamem_start=datamem_start,
                            dynamic_offset_counter=dynamic_offset_counter,
                            zero_pad_i=zero_pad_i,
                            pos=pos)

    return holder, maxlen, dynamic_offset_counter, datamem_start
예제 #8
0
파일: abi.py 프로젝트: sambacha/vyper-xcode
 def dynamic_size_bound(self):
     """Return the upper bound of the dynamic part: a 32-byte length word
     followed by ``self.bytes_bound`` bytes of data rounded up by ``ceil32``."""
     length_word_bytes = 32
     padded_data_bytes = ceil32(self.bytes_bound)
     return length_word_bytes + padded_data_bytes
예제 #9
0
def pack_args_by_32(holder,
                    maxlen,
                    arg,
                    typ,
                    context,
                    placeholder,
                    dynamic_offset_counter=None,
                    datamem_start=None,
                    zero_pad_i=None,
                    pos=None):
    """
    Copy necessary variables to pre-allocated memory section.

    :param holder: Complete holder for all args; LLL statements are appended to it.
    :param maxlen: Total length in bytes of the full arg section (static + dynamic).
    :param arg: Current arg to pack (an AST node, or an LLLnode for a preallocated value).
    :param typ: Expected type of the arg; BaseType, ByteArrayLike and ListType are handled,
                any other type leaves ``holder`` and ``maxlen`` untouched.
    :param context: Context of arg
    :param placeholder: Static placeholder for static argument part.
    :param dynamic_offset_counter: position counter stored in static args.
    :param datamem_start: position where the whole datamem section starts.
    :param zero_pad_i: memory location of the loop counter used while zero padding byte arrays.
    :param pos: source position forwarded to the conversion / copy helpers.
    :return: ``(holder, maxlen)``; ``maxlen`` grows by 32 bytes for every list element
             beyond the first.
    :raises TypeMismatchException: the element type of a provided list differs from ``typ``.
    """

    if isinstance(typ, BaseType):
        if isinstance(arg, LLLnode):
            value = unwrap_location(arg)
        else:
            value = Expr(arg, context).lll_node
            value = base_type_conversion(value, value.typ, typ, pos)
        holder.append(
            LLLnode.from_list(['mstore', placeholder, value],
                              typ=typ,
                              location='memory'))
    elif isinstance(typ, ByteArrayLike):

        if isinstance(arg, LLLnode):  # Is prealloacted variable.
            source_lll = arg
        else:
            source_lll = Expr(arg, context).lll_node

        # Set static offset, in arg slot.
        holder.append(
            LLLnode.from_list(
                ['mstore', placeholder, ['mload', dynamic_offset_counter]]))
        # Get the beginning to write the ByteArray to.
        dest_placeholder = LLLnode.from_list(
            ['add', datamem_start, ['mload', dynamic_offset_counter]],
            typ=typ,
            location='memory',
            annotation="pack_args_by_32:dest_placeholder")
        copier = make_byte_array_copier(dest_placeholder, source_lll, pos=pos)
        holder.append(copier)
        # Add zero padding.
        new_maxlen = ceil32(source_lll.typ.maxlen)

        # Write zero bytes from the end of the actual data up to the next
        # 32-byte boundary (`_ceil32_end`), iterating at most `new_maxlen` times.
        holder.append([
            'with',
            '_ceil32_end',
            ['ceil32', ['mload', dest_placeholder]],
            [
                'seq',
                [
                    'with',
                    '_bytearray_loc',
                    dest_placeholder,
                    [
                        'seq',
                        [
                            'repeat',
                            zero_pad_i,
                            ['mload', '_bytearray_loc'],
                            new_maxlen,
                            [
                                'seq',
                                # stay within allocated bounds
                                [
                                    'if',
                                    [
                                        'ge', ['mload', zero_pad_i],
                                        '_ceil32_end'
                                    ], 'break'
                                ],
                                [
                                    'mstore8',
                                    [
                                        'add', ['add', '_bytearray_loc', 32],
                                        ['mload', zero_pad_i]
                                    ],
                                    0,
                                ],
                            ]
                        ],
                    ]
                ],
            ]
        ])

        # Increment offset counter: padded data length plus the length word.
        increment_counter = LLLnode.from_list(
            [
                'mstore',
                dynamic_offset_counter,
                [
                    'add',
                    [
                        'add', ['mload', dynamic_offset_counter],
                        ['ceil32', ['mload', dest_placeholder]]
                    ],
                    32,
                ],
            ],
            annotation='Increment dynamic offset counter')
        holder.append(increment_counter)
    elif isinstance(typ, ListType):
        # One word per element; the first one is already counted by the caller.
        maxlen += (typ.count - 1) * 32
        typ = typ.subtype

        def check_list_type_match(provided):  # Check list types match.
            if provided != typ:
                raise TypeMismatchException(
                    "Log list type '%s' does not match provided, expected '%s'"
                    % (provided, typ))

        # NOTE: Below code could be refactored into iterators/getter functions for each type of
        #       repetitive loop. But seeing how each one is a unique for loop, and in which way
        #       the sub value makes the difference in each type of list clearer.

        # List from storage
        if isinstance(arg, ast.Attribute) and arg.value.id == 'self':
            stor_list = context.globals[arg.attr]
            check_list_type_match(stor_list.typ.subtype)
            size = stor_list.typ.count
            mem_offset = 0
            for i in range(0, size):
                storage_offset = i
                arg2 = LLLnode.from_list(
                    [
                        'sload',
                        [
                            'add', ['sha3_32',
                                    Expr(arg, context).lll_node],
                            storage_offset
                        ]
                    ],
                    typ=typ,
                )
                holder, maxlen = pack_args_by_32(
                    holder,
                    maxlen,
                    arg2,
                    typ,
                    context,
                    placeholder + mem_offset,
                    pos=pos,
                )
                mem_offset += get_size_of_type(typ) * 32

        # List from variable.
        elif isinstance(arg, ast.Name):
            list_var = context.vars[arg.id]
            size = list_var.size
            # Kept separate from the `pos` parameter so the source position,
            # not the variable's memory address, is what gets forwarded to
            # the recursive call below.
            var_pos = list_var.pos
            check_list_type_match(list_var.typ.subtype)
            mem_offset = 0
            for _ in range(0, size):
                arg2 = LLLnode.from_list(
                    var_pos + mem_offset,
                    typ=typ,
                    location=list_var.location)
                holder, maxlen = pack_args_by_32(
                    holder,
                    maxlen,
                    arg2,
                    typ,
                    context,
                    placeholder + mem_offset,
                    pos=pos,
                )
                mem_offset += get_size_of_type(typ) * 32

        # List from list literal.
        else:
            mem_offset = 0
            for arg2 in arg.elts:
                holder, maxlen = pack_args_by_32(
                    holder,
                    maxlen,
                    arg2,
                    typ,
                    context,
                    placeholder + mem_offset,
                    pos=pos,
                )
                mem_offset += get_size_of_type(typ) * 32
    return holder, maxlen
예제 #10
0
 def memory_bytes_required(self):
     """Return the memory footprint in bytes: ``self.maxlen`` rounded up by
     ``ceil32`` plus ``DYNAMIC_ARRAY_OVERHEAD`` 32-byte words."""
     padded_data_bytes = ceil32(self.maxlen)
     overhead_bytes = 32 * DYNAMIC_ARRAY_OVERHEAD
     return padded_data_bytes + overhead_bytes
예제 #11
0
    def from_declaration(cls, code, global_ctx):
        """Build an event signature from an event declaration node.

        The event name comes from ``code.target.id``; when the event has
        arguments they are read from the ``name: type`` dict found in
        ``code.annotation.args[0]``.  A type wrapped in ``indexed(...)``
        marks the argument as a topic.

        :param code: declaration node of the event.
        :param global_ctx: provides ``_custom_units`` / ``_structs`` /
            ``_constants`` for name validation and ``parse_type`` for types.
        :return: ``cls(name, args, indexed_list, event_id, sig)``.
        :raises EventDeclarationException: malformed key or type node, an
            indexed ``bytes`` argument longer than 32, or more than 3 topics.
        :raises VariableDeclarationException: invalid or duplicate argument name.
        :raises InvalidTypeException: argument without a type.
        """
        event_name = code.target.id
        check_valid_varname(event_name,
                            global_ctx._custom_units,
                            global_ctx._structs,
                            global_ctx._constants,
                            pos=code,
                            error_prefix="Event name invalid. ",
                            exc=EventDeclarationException)

        records = []
        indexed_flags = []
        topic_total = 1  # starts at 1; the limit check below allows up to 4
        offset = 0  # running position handed to each VariableRecord

        declared = code.annotation.args
        if declared:
            arg_dict = declared[0]
            for key_node, type_node in zip(arg_dict.keys, arg_dict.values):
                # Structural validation of the key and type nodes.
                if not isinstance(key_node, ast.Name):
                    raise EventDeclarationException(
                        'Invalid key type, expected a valid name.',
                        key_node,
                    )
                if not isinstance(type_node,
                                  (ast.Name, ast.Call, ast.Subscript)):
                    raise EventDeclarationException(
                        'Invalid event argument type.', type_node)
                if (isinstance(type_node, ast.Call)
                        and not isinstance(type_node.func, ast.Name)):
                    raise EventDeclarationException(
                        'Invalid event argument type', type_node)

                arg = key_node.id

                # `indexed(<type>)` marks a topic; unwrap to the inner type.
                is_indexed = (isinstance(type_node, ast.Call)
                              and type_node.func.id == 'indexed')
                if is_indexed:
                    type_node = type_node.args[0]
                    topic_total += 1
                indexed_flags.append(is_indexed)

                oversized_bytes = (
                    isinstance(type_node, ast.Subscript)
                    and getattr(type_node.value, 'id', None) == 'bytes'
                    and type_node.slice.value.n > 32
                )
                if oversized_bytes and is_indexed:
                    raise EventDeclarationException(
                        "Indexed arguments are limited to 32 bytes")
                if topic_total > 4:
                    raise EventDeclarationException(
                        f"Maximum of 3 topics {topic_total - 1} given",
                        arg,
                    )
                if not isinstance(arg, str):
                    raise VariableDeclarationException(
                        "Argument name invalid", arg)
                if not type_node:
                    raise InvalidTypeException("Argument must have type", arg)
                check_valid_varname(
                    arg,
                    global_ctx._custom_units,
                    global_ctx._structs,
                    global_ctx._constants,
                    pos=key_node,
                    error_prefix="Event argument name invalid or reserved.",
                )
                if any(existing.name == arg for existing in records):
                    raise VariableDeclarationException(
                        "Duplicate function argument name: " + arg,
                        key_node,
                    )

                # Can struct be logged?
                parsed = global_ctx.parse_type(type_node, None)
                records.append(VariableRecord(arg, offset, parsed, False))
                if isinstance(parsed, ByteArrayType):
                    step = ceil32(type_node.slice.value.n)
                else:
                    step = get_size_of_type(parsed) * 32
                offset += step

        canonical_types = ','.join(
            canonicalize_type(record.typ, flag)
            for record, flag in zip(records, indexed_flags)
        )
        sig = event_name + '(' + canonical_types + ')'
        event_id = bytes_to_int(keccak256(bytes(sig, 'utf-8')))
        return cls(event_name, records, indexed_flags, event_id, sig)
예제 #12
0
def pack_logging_data(expected_data, args, context, pos):
    """
    Build the LLL that lays out the data (non-topic) arguments of a log call in memory.

    :param expected_data: declared event data arguments; only ``.typ`` is read here.
    :param args: AST nodes of the values supplied at the call site, paired
                 positionally with ``expected_data``.
    :param context: function context used to allocate internal variables.
    :param pos: source position forwarded to errors and helper calls.
    :return: ``(holder, maxlen, dynamic_offset_counter, datamem_start)`` where
             ``holder`` is a ``["seq", ...]`` list of LLL statements, ``maxlen``
             the upper bound in bytes of the packed data,
             ``dynamic_offset_counter`` the variable tracking the next dynamic
             write offset (``None`` if no byte arrays are logged) and
             ``datamem_start`` the address where the packed data begins.
    :raises TypeMismatch: a string literal is longer than its type allows.
    :raises StructureException: ``empty(...)`` is passed for a Bytes/String argument.
    """
    # Checks to see if there's any data
    if not args:
        return ["seq"], 0, None, 0
    holder = ["seq"]
    maxlen = len(args) * 32  # total size of all packed args (upper limit)

    # Unroll any function calls, to temp variables.
    # Maps argument index -> LLLnode of the temp variable holding its value.
    prealloacted = {}
    for idx, (arg, _expected_arg) in enumerate(zip(args, expected_data)):

        # Calls to `empty` are excluded from pre-evaluation.
        if isinstance(arg, (vy_ast.Str, vy_ast.Call)) and arg.get("func.id") != "empty":
            expr = Expr(arg, context)
            source_lll = expr.lll_node
            typ = source_lll.typ

            if isinstance(arg, vy_ast.Str):
                if len(arg.s) > typ.maxlen:
                    raise TypeMismatch(f"Data input bytes are to big: {len(arg.s)} {typ}", pos)

            tmp_variable = context.new_internal_variable(source_lll.typ)
            tmp_variable_node = LLLnode.from_list(
                tmp_variable,
                typ=source_lll.typ,
                pos=getpos(arg),
                location="memory",
                annotation=f"log_prealloacted {source_lll.typ}",
            )
            # Store len.
            # holder.append(['mstore', len_placeholder, ['mload', unwrap_location(source_lll)]])
            # Copy bytes.

            # Evaluate the expression once into the temp variable; later
            # packing steps read from the temp instead of the AST node.
            holder.append(
                make_setter(tmp_variable_node, source_lll, pos=getpos(arg), location="memory")
            )
            prealloacted[idx] = tmp_variable_node

    # Create internal variables for dynamic and static args.
    # Byte arrays are represented by a single word in the static section.
    static_types = []
    for data in expected_data:
        static_types.append(data.typ if not isinstance(data.typ, ByteArrayLike) else BaseType(32))

    requires_dynamic_offset = any(isinstance(data.typ, ByteArrayLike) for data in expected_data)

    dynamic_offset_counter = None
    if requires_dynamic_offset:
        # NOTE: `dynamic_placeholder` is only bound on this path; it is only
        # read below under the same condition.
        dynamic_offset_counter = context.new_internal_variable(BaseType(32))
        dynamic_placeholder = context.new_internal_variable(BaseType(32))

    # Allocated after the dynamic bookkeeping variables; `datamem_start`
    # below is derived from `dynamic_placeholder + 32`, so this order matters.
    static_vars = [context.new_internal_variable(i) for i in static_types]

    # Populate static placeholders.
    for i, (arg, data) in enumerate(zip(args, expected_data)):
        typ = data.typ
        placeholder = static_vars[i]
        if not isinstance(typ, ByteArrayLike):
            # pack_args_by_32 may grow maxlen (e.g. for list arguments).
            holder, maxlen = pack_args_by_32(
                holder, maxlen, prealloacted.get(i, arg), typ, context, placeholder, pos=pos,
            )

    # Dynamic position starts right after the static args.
    if requires_dynamic_offset:
        holder.append(LLLnode.from_list(["mstore", dynamic_offset_counter, maxlen]))

    # Calculate maximum dynamic offset placeholders, used for gas estimation.
    for _arg, data in zip(args, expected_data):
        typ = data.typ
        if isinstance(typ, ByteArrayLike):
            # one length word plus the padded maximum data size
            maxlen += 32 + ceil32(typ.maxlen)

    if requires_dynamic_offset:
        datamem_start = dynamic_placeholder + 32
    else:
        datamem_start = static_vars[0]

    # Copy necessary data into allocated dynamic section.
    for i, (arg, data) in enumerate(zip(args, expected_data)):
        typ = data.typ
        if isinstance(typ, ByteArrayLike):
            if isinstance(arg, vy_ast.Call) and arg.func.get("id") == "empty":
                # TODO add support for this
                raise StructureException(
                    "Cannot use `empty` on Bytes or String types within an event log", arg
                )
            # Return value is ignored here: holder is extended in place.
            pack_args_by_32(
                holder=holder,
                maxlen=maxlen,
                arg=prealloacted.get(i, arg),
                typ=typ,
                context=context,
                placeholder=static_vars[i],
                datamem_start=datamem_start,
                dynamic_offset_counter=dynamic_offset_counter,
                pos=pos,
            )

    return holder, maxlen, dynamic_offset_counter, datamem_start
예제 #13
0
def apply_general_optimizations(node: LLLnode) -> LLLnode:
    """
    Return an optimized copy of ``node``.

    The arguments are optimized recursively first; then at most one of the
    rewrite rules below (an ``elif`` chain, so the first match wins) is
    applied to the node itself.  Unless a rule returns early, the result is
    rebuilt with ``LLLnode.from_list`` using the original node's metadata.

    :raises Exception: when a clamp with constant operands can never succeed.
    """
    # TODO add rules for modulus powers of 2
    # TODO refactor this into several functions

    argz = [apply_general_optimizations(arg) for arg in node.args]

    # Start from the original node's attributes; rules override value,
    # annotation and argz as needed.
    value = node.value
    typ = node.typ
    location = node.location
    pos = node.pos
    annotation = node.annotation
    add_gas_estimate = node.add_gas_estimate
    valency = node.valency

    if node.value == "seq":
        # Both helpers are handed the argument list itself (presumably they
        # merge adjacent statements in place -- confirm in their definitions).
        _merge_memzero(argz)
        _merge_calldataload(argz)

    # NOTE: this `if` starts a new chain, so a "seq" node also flows through
    # the checks below.
    if node.value in arith and int_at(argz, 0) and int_at(argz, 1):
        # compile-time arithmetic
        left, right = get_int_at(argz, 0), get_int_at(argz, 1)
        # `node.value in arith` implies that `node.value` is a `str`
        fn, symb = arith[str(node.value)]
        value = fn(left, right)
        # Keep a readable trace of the folded expression in the annotation.
        if argz[0].annotation and argz[1].annotation:
            annotation = argz[0].annotation + symb + argz[1].annotation
        elif argz[0].annotation or argz[1].annotation:
            annotation = ((argz[0].annotation or str(left)) + symb +
                          (argz[1].annotation or str(right)))
        else:
            annotation = ""

        argz = []

    elif node.value == "ceil32" and int_at(argz, 0):
        # ceil32 of a constant is folded at compile time.
        t = argz[0]
        annotation = f"ceil32({t.value})"
        argz = []
        value = ceil32(t.value)

    elif node.value == "add" and get_int_at(argz, 0) == 0:
        # (add 0 x) -> x
        value = argz[1].value
        annotation = argz[1].annotation
        argz = argz[1].args

    elif node.value == "add" and get_int_at(argz, 1) == 0:
        # (add x 0) -> x
        value = argz[0].value
        annotation = argz[0].annotation
        argz = argz[0].args

    elif (node.value in ("clamp", "uclamp") and int_at(argz, 0)
          and int_at(argz, 1) and int_at(argz, 2)):
        # All three operands constant: either the clamp can never pass, or it
        # is a no-op and the middle operand is returned directly (skipping
        # the node rebuild and gas bookkeeping at the end of the function).
        if get_int_at(argz, 0, True) > get_int_at(argz, 1,
                                                  True):  # type: ignore
            raise Exception("Clamp always fails")
        elif get_int_at(argz, 1, True) > get_int_at(argz, 2,
                                                    True):  # type: ignore
            raise Exception("Clamp always fails")
        else:
            return argz[1]

    elif node.value in ("clamp", "uclamp") and int_at(argz, 0) and int_at(
            argz, 1):
        # First two operands constant: the first comparison is decided here,
        # leaving only the second one for runtime.
        if get_int_at(argz, 0, True) > get_int_at(argz, 1,
                                                  True):  # type: ignore
            raise Exception("Clamp always fails")
        else:
            # i.e., clample or uclample
            value += "le"  # type: ignore
            argz = [argz[1], argz[2]]

    elif node.value == "uclamplt" and int_at(argz, 0) and int_at(argz, 1):
        # Both operands constant: fails unless arg0 < arg1, otherwise the
        # node collapses to arg0's value.
        if get_int_at(argz, 0, True) >= get_int_at(argz, 1,
                                                   True):  # type: ignore
            raise Exception("Clamp always fails")
        value = argz[0].value
        argz = []

    elif node.value == "clamp_nonzero" and int_at(argz, 0):
        # Constant operand: collapses to the constant, or fails if it is 0.
        if get_int_at(argz, 0) != 0:
            value = argz[0].value
            argz = []
        else:
            raise Exception("Clamp always fails")

    # TODO: (uclampgt 0 x) -> (iszero (iszero x))
    # TODO: more clamp rules

    # [eq, x, 0] is the same as [iszero, x].
    # TODO handle (ne 0 x) as well
    elif node.value == "eq" and int_at(argz, 1) and argz[1].value == 0:
        value = "iszero"
        argz = [argz[0]]

    # TODO handle (ne -1 x) as well
    # (eq x -1) -> (iszero (not x))
    elif node.value == "eq" and int_at(argz, 1) and argz[1].value == -1:
        value = "iszero"
        argz = [LLLnode.from_list(["not", argz[0]])]

    # (eq x y) has the same truthyness as (iszero (xor x y))
    # rewrite 'eq' as 'xor' in places where truthy is accepted.
    # (the sequence (if (iszero (xor x y))) will be translated to
    #  XOR ISZERO ISZERO ..JUMPI and the ISZERO ISZERO will be
    #  optimized out)
    elif node.value in ("if", "assert") and argz[0].value == "eq":
        # A plain list is stored here; it is passed through
        # LLLnode.from_list as part of the rebuild below.
        argz[0] = ["iszero", ["xor", *argz[0].args]]  # type: ignore

    elif node.value == "if" and len(argz) == 3:
        # if(x) compiles to jumpi(_, iszero(x))
        # there is an asm optimization for the sequence ISZERO ISZERO..JUMPI
        # so we swap the branches here to activate that optimization.
        cond = argz[0]
        true_branch = argz[1]
        false_branch = argz[2]
        contra_cond = LLLnode.from_list(["iszero", cond])

        argz = [contra_cond, false_branch, true_branch]

    # Rebuild the node from the (possibly rewritten) value and arguments,
    # carrying over the original metadata.
    ret = LLLnode.from_list(
        [value, *argz],
        typ=typ,
        location=location,
        pos=pos,
        annotation=annotation,
        add_gas_estimate=add_gas_estimate,
        valency=valency,
    )
    if node.total_gas is not None:
        # Adjust the recorded total by the difference between the old and
        # new node's own gas figure.
        ret.total_gas = node.total_gas - node.gas + ret.gas
        ret.func_name = node.func_name

    return ret
예제 #14
0
def _optimize(node: IRnode, parent: Optional[IRnode]) -> Tuple[bool, IRnode]:
    """
    Optimize `node` bottom-up and report whether anything was rewritten.

    The children of `node` are optimized first (with `node` passed as their
    parent). Afterwards the rewrite rules below are tried in a fixed order;
    the first rule that applies produces the result. Any rebuilt node is
    passed through `_optimize` again before being returned.

    Arguments:
        node: the IR node to optimize.
        parent: the node which contains `node`, or None. Only its `value`
            is consulted, and only when optimizing arithmetic operations.

    Returns:
        A tuple `(changed, result)`. When neither `node` nor any of its
        children was rewritten, this is `(False, node)`, i.e. the very same
        object that was passed in.

    Raises:
        StaticAssertionException: if the condition of an `assert` or
            `assert_unreachable` is an integer which `_evm_int` maps to 0.
    """
    # captured before any rewriting so that it can be handed to
    # _check_symbols together with the rebuilt node
    starting_symbols = node.unique_symbols

    # optimize the children first; each result is a (changed, node) pair
    child_results = [_optimize(child, node) for child in node.args]
    args_changed = any(child_changed for child_changed, _ in child_results)
    argz: list = [child for _, child in child_results]

    value = node.value
    typ = node.typ
    location = node.location
    source_pos = node.source_pos
    error_msg = node.error_msg
    annotation = node.annotation
    add_gas_estimate = node.add_gas_estimate

    changed = False

    # the symbols check cannot be enforced in general (e.g. eliminating a
    # dead branch will almost always trip it), so it is opt-in; currently
    # only the binop rewrite below turns it on.
    should_check_symbols = False

    def emit(new_value, new_args):
        # NOTE: `changed`, `annotation` and `should_check_symbols` are read
        # from the enclosing scope at call time, so the rules below set
        # them *before* calling emit().
        if not (changed or args_changed):
            # nothing was rewritten - hand back the original node and avoid
            # IRnode.from_list, which may be (compile-time) expensive
            return False, node

        rebuilt = IRnode.from_list(
            [new_value, *new_args],
            typ=typ,
            location=location,
            source_pos=source_pos,
            error_msg=error_msg,
            annotation=annotation,
            add_gas_estimate=add_gas_estimate,
        )

        if should_check_symbols:
            _check_symbols(starting_symbols, rebuilt)

        # the rebuilt node may expose further optimizations
        return True, _optimize(rebuilt, parent)[1]

    if value == "seq":
        # each of these passes gets the argument list and reports
        # whether it changed anything
        for seq_pass in (_merge_memzero, _merge_calldataload, _remove_empty_seqs):
            changed |= seq_pass(argz)

        # unwrap (seq x) into (x), both for cleanliness and so that
        # the wrapper does not block other optimizations
        if len(argz) == 1:
            _, only_child = _optimize(argz[0], parent)
            return True, only_child

        return emit(value, argz)

    if value in arith:
        parent_op = None if parent is None else parent.value

        binop_result = _optimize_binop(value, argz, annotation, parent_op)
        if binop_result is not None:
            changed = True
            should_check_symbols = True
            value, argz, annotation = binop_result  # type: ignore
            return emit(value, argz)

    ###
    # BITWISE OPS
    ###

    # note: bitwise expressions are deliberately optimized only lightly,
    # since they may have been hand optimized for codesize. more can be
    # done once there is a pipeline which optimizes for codesize.
    is_shift = value in ("shl", "shr", "sar")
    if is_shift and argz[0].value == 0:
        # shifting by zero bits is the identity: replace the shift with
        # its second operand
        shifted = argz[1]
        changed = True
        annotation = shifted.annotation
        return emit(shifted.value, shifted.args)

    if value == "ceil32" and _is_int(argz[0]):
        # constant-fold ceil32 of a literal
        operand = argz[0].value
        changed = True
        annotation = f"ceil32({operand})"
        return emit(ceil32(operand), [])

    if value == "iszero" and _is_int(argz[0]):
        # constant-fold iszero of a literal into 1 or 0
        changed = True
        return emit(1 if argz[0].value == 0 else 0, [])

    if value == "if":
        cond = argz[0]

        if _is_int(cond):
            # the condition is a constant, so only one branch can be taken
            changed = True
            if _evm_int(cond) == 0:
                # always false: keep the else branch (nothing, if absent)
                return emit("seq", argz[2:])
            # always true: keep the first branch
            return emit("seq", [argz[1]])

        if len(argz) == 3 and cond.value != "iszero":
            # if(x) compiles to jumpi(_, iszero(x)), and there is an asm
            # optimization for the sequence ISZERO ISZERO..JUMPI. negating
            # the condition and swapping the branches activates it.
            true_branch, false_branch = argz[1], argz[2]
            negated_cond = IRnode.from_list(["iszero", cond])

            argz = [negated_cond, false_branch, true_branch]
            changed = True
            return emit("if", argz)

    if value in ("assert", "assert_unreachable") and _is_int(argz[0]):
        if _evm_int(argz[0]) != 0:
            # the assertion can never fail, so drop it
            changed = True
            return emit("seq", [])

        raise StaticAssertionException(
            f"assertion found to fail at compile time. (hint: did you mean `raise`?) {node}",
            source_pos,
        )

    # no rule applied to this node itself
    return emit(value, argz)