Пример #1
0
def abi_type_of(lll_typ):
    """
    Translate a compiler (LLL-level) type into its ABI type descriptor.

    Base types are mapped by name; tuple-like types, lists, byte arrays and
    strings are mapped structurally, recursing into members / element types.

    Raises
    ------
    CompilerPanic
        If the base type name or the type class is not recognized.
    """
    if isinstance(lll_typ, BaseType):
        base_name = lll_typ.typ
        # factories are lazy so only the selected ABI object is constructed
        base_factories = {
            'uint256': lambda: ABI_GIntM(256, False),
            'int128': lambda: ABI_GIntM(128, True),
            'address': lambda: ABI_Address(),
            'bytes32': lambda: ABI_BytesM(32),
            'bool': lambda: ABI_Bool(),
            # decimals are represented as signed fixed168x10
            'decimal': lambda: ABI_FixedMxN(168, 10, True),
        }
        if base_name not in base_factories:
            raise CompilerPanic(f'Unrecognized type {base_name}')
        return base_factories[base_name]()

    if isinstance(lll_typ, TupleLike):
        member_abi_types = [abi_type_of(member)
                            for member in lll_typ.tuple_members()]
        return ABI_Tuple(member_abi_types)

    if isinstance(lll_typ, ListType):
        element_abi_type = abi_type_of(lll_typ.subtype)
        return ABI_StaticArray(element_abi_type, lll_typ.count)

    if isinstance(lll_typ, ByteArrayType):
        return ABI_Bytes(lll_typ.maxlen)

    if isinstance(lll_typ, StringType):
        return ABI_String(lll_typ.maxlen)

    raise CompilerPanic(f'Unrecognized type {lll_typ}')
Пример #2
0
    def __init__(self, m_bits, n_places, signed):
        """
        Describe a fixed-point ABI type with ``m_bits`` total bits and
        ``n_places`` decimal places.

        ``m_bits`` must be a multiple of 8 in the range 8..256 and
        ``n_places`` must lie in 1..80; otherwise CompilerPanic is raised.
        ``signed`` is stored as given.
        """
        m_bits_valid = 0 < m_bits <= 256 and m_bits % 8 == 0
        if not m_bits_valid:
            raise CompilerPanic('Invalid M for FixedMxN')
        n_places_valid = 0 < n_places <= 80
        if not n_places_valid:
            raise CompilerPanic('Invalid N for FixedMxN')

        self.m_bits, self.n_places, self.signed = m_bits, n_places, signed
Пример #3
0
    def replace_in_tree(self, old_node: srilangNode,
                        new_node: srilangNode) -> None:
        """
        Perform an in-place substitution of a node within the tree.

        The field of ``old_node``'s parent that references it (either directly
        or as an element of a list-valued field) is updated to reference
        ``new_node`` instead. ``new_node`` takes over the parent and depth of
        ``old_node``, and the parent's set of children is updated.

        Parameters
        ----------
        old_node : srilangNode
            Node object to be replaced. If the node does not currently exist
            within the AST, a `CompilerPanic` is raised.
        new_node : srilangNode
            Node object to replace old_node.

        Returns
        -------
        None

        Raises
        ------
        CompilerPanic
            If ``old_node`` is not a descendant of this node, is not among its
            parent's children, is referenced by none of the parent's fields, or
            is referenced more than once by the parent's fields.
        """
        # Validate membership before touching `_parent`, so that a node
        # outside this tree is reported as such.
        if old_node not in self.get_descendants(type(old_node)):
            raise CompilerPanic(
                "Node to be replaced does not exist within the tree")

        parent = old_node._parent
        if old_node not in parent._children:
            raise CompilerPanic(
                "Node to be replaced does not exist within parent children")

        is_replaced = False
        for key in parent.get_fields():
            obj = getattr(parent, key, None)
            if obj == old_node:
                if is_replaced:
                    raise CompilerPanic(
                        "Node to be replaced exists as multiple members in parent"
                    )
                setattr(parent, key, new_node)
                is_replaced = True
            elif isinstance(obj, list) and old_node in obj:
                # The node must appear exactly once across all parent fields,
                # including within a single list field.
                if is_replaced or obj.count(old_node) > 1:
                    raise CompilerPanic(
                        "Node to be replaced exists as multiple members in parent"
                    )
                obj[obj.index(old_node)] = new_node
                is_replaced = True
        if not is_replaced:
            raise CompilerPanic(
                "Node to be replaced does not exist within parent members")

        # Re-link the bookkeeping: swap the child entry and give the new node
        # the old node's position in the tree.
        parent._children.remove(old_node)

        new_node._parent = parent
        new_node._depth = old_node._depth
        parent._children.add(new_node)
Пример #4
0
 def increase_memory(self, size: int) -> Tuple[int, int]:
     """
     Advance the memory allocation pointer ``self.next_mem`` by ``size``.

     Returns the pointer value before and after the advance as a
     ``(start, end)`` pair. ``size`` must be a multiple of 32, otherwise
     CompilerPanic is raised and the pointer is left unchanged.
     """
     if size % 32:
         raise CompilerPanic(
             'Memory misaligment, only multiples of 32 supported.')
     start = self.next_mem
     self.next_mem = start + size
     return start, self.next_mem
Пример #5
0
def lazy_abi_decode(typ, src, pos=None):
    """
    Build an LLL expression that decodes ABI-encoded data of type ``typ``
    located at ``src``.

    Lists and tuple-like types are decoded member by member into a ``multi``
    node. For each dynamic member, the head slot holds an offset relative to
    ``src``, which is loaded and used to locate the member's data. Base types
    and byte arrays are loaded via ``unwrap_location``.

    Raises
    ------
    CompilerPanic
        If ``typ`` is not a list, tuple-like, base or byte-array type.
    """
    if isinstance(typ, (ListType, TupleLike)):
        if isinstance(typ, TupleLike):
            member_types = typ.tuple_members()
        else:
            # ListType exposes its element type as `subtype` (not `subtyp`).
            member_types = [typ.subtype for _ in range(typ.count)]
        static_ofst = 0
        decoded = []
        for member_typ in member_types:
            child_abi_t = abi_type_of(member_typ)
            loc = _add_ofst(src, static_ofst)
            if child_abi_t.is_dynamic():
                # load the offset word stored in the head slot `loc`; it is
                # the (location-independent) offset from the start of the
                # src buffer.
                dyn_ofst = unwrap_location(loc)
                loc = _add_ofst(src, dyn_ofst)
            decoded.append(lazy_abi_decode(member_typ, loc, pos))
            static_ofst += child_abi_t.embedded_static_size()

        return LLLnode.from_list(['multi'] + decoded, typ=typ, pos=pos)

    elif isinstance(typ, (BaseType, ByteArrayLike)):
        return unwrap_location(src)
    else:
        raise CompilerPanic(f'unknown type for lazy_abi_decode {typ}')
Пример #6
0
def validate_call_args(node: sri_ast.Call,
                       arg_count: Union[int, tuple],
                       kwargs: Optional[list] = None) -> None:
    """
    Check the positional and keyword arguments of a Call node.

    Only the number of positional arguments and the keyword names are
    checked; argument types are not.

    Arguments
    ---------
    node : Call
        srilang ast Call node to be validated.
    arg_count : int | tuple
        Exact number of positional arguments, or a ``(min, max)`` tuple of
        the allowed range.
    kwargs : list, optional
        Names of the accepted keyword arguments. When ``arg_count`` is a
        tuple, positional arguments beyond the minimum are treated as filling
        the leading entries of this list, so those names may not also be
        passed as keywords.

    Returns
    -------
        None. Raises an exception when the arguments are invalid.

    Raises
    ------
    StructureException
        If ``node`` is not a Call, or ``**kwargs`` unpacking is used.
    CompilerPanic
        If ``arg_count`` is neither an int nor a tuple.
    ArgumentException
        On a wrong positional count, an unexpected or unknown keyword, or a
        keyword that duplicates an overflowing positional argument.
    """
    allowed_keywords = [] if kwargs is None else kwargs
    if not isinstance(node, sri_ast.Call):
        raise StructureException("Expected Call", node)
    if not isinstance(arg_count, (int, tuple)):
        raise CompilerPanic(
            f"Invalid type for arg_count: {type(arg_count).__name__}")

    num_positional = len(node.args)
    if isinstance(arg_count, int):
        if num_positional != arg_count:
            raise ArgumentException(
                f"Invalid argument count: expected {arg_count}, got {num_positional}",
                node)
    elif not arg_count[0] <= num_positional <= arg_count[1]:
        raise ArgumentException(
            f"Invalid argument count: expected between "
            f"{arg_count[0]} and {arg_count[1]}, got {num_positional}",
            node,
        )

    if node.keywords and not allowed_keywords:
        raise ArgumentException("Keyword arguments are not accepted here",
                                node.keywords[0])
    for keyword in node.keywords:
        name = keyword.arg
        if name is None:
            raise StructureException("Use of **kwargs is not supported",
                                     keyword.value)
        if name not in allowed_keywords:
            raise ArgumentException(f"Invalid keyword argument '{name}'",
                                    keyword)
        if isinstance(arg_count, tuple):
            # positional args past the minimum occupy the first keyword slots
            overflow = num_positional - arg_count[0]
            if allowed_keywords.index(name) < overflow:
                raise ArgumentException(
                    f"'{name}' was given as a positional argument", keyword)
Пример #7
0
def dict_to_ast(
        ast_struct: Union[Dict, List]) -> Union[sri_ast.srilangNode, List]:
    """
    Rebuild srilang AST node objects from their dict form.

    A single dict yields one node; a list of dicts yields a list of nodes.
    Any other input raises CompilerPanic.
    """
    if isinstance(ast_struct, list):
        return list(map(sri_ast.get_node, ast_struct))
    if isinstance(ast_struct, dict):
        return sri_ast.get_node(ast_struct)
    raise CompilerPanic(f'Unknown ast_struct provided: "{type(ast_struct)}".')
Пример #8
0
 def __init__(self,
              typ,
              unit=False,
              positional=False,
              override_signature=False,
              is_literal=False):
     """
     Store the type information for this node.

     ``unit`` and ``positional`` must be falsy: units are no longer
     supported, and a truthy value for either raises CompilerPanic (after
     ``typ`` has already been stored).
     """
     self.typ = typ
     if any((unit, positional)):
         raise CompilerPanic("Units are no longer supported")
     self.override_signature, self.is_literal = override_signature, is_literal
Пример #9
0
def version_check(begin: Optional[str] = None, end: Optional[str] = None) -> bool:
    """
    Return True when the active EVM version lies within ``[begin, end]``.

    Both bounds are inclusive and looked up in ``EVM_VERSIONS``. An omitted
    bound defaults to the lowest / highest known version, but at least one
    bound must be given (otherwise CompilerPanic is raised).
    """
    if begin is None and end is None:
        raise CompilerPanic("Either beginning or end fork ruleset must be set.")
    lower = min(EVM_VERSIONS.values()) if begin is None else EVM_VERSIONS[begin]
    upper = max(EVM_VERSIONS.values()) if end is None else EVM_VERSIONS[end]
    return lower <= active_evm_version <= upper
Пример #10
0
def ast_to_dict(
        ast_struct: Union[sri_ast.srilangNode, List]) -> Union[Dict, List]:
    """
    Serialize a srilang AST node, or a list of nodes, via ``to_dict()``.

    The result is suitable for output to the user. Any other input raises
    CompilerPanic.
    """
    if isinstance(ast_struct, sri_ast.srilangNode):
        return ast_struct.to_dict()
    if isinstance(ast_struct, list):
        return list(map(lambda node: node.to_dict(), ast_struct))
    raise CompilerPanic(
        f'Unknown srilang AST node provided: "{type(ast_struct)}".')
Пример #11
0
    def visit_Num(self, node):
        """
        Refine the srilang node class of a numeric literal.

        Python represents integer and float literals alike as `Num`, and
        integers may be written in binary, octal, decimal or hexadecimal.
        This visitor inspects the literal's source text and Python value to
        set `ast_type` to one of `Hex`, `Octal`, `Bytes` (binary literals),
        `Decimal` (floats) or `Int`.

        Raises
        ------
        SyntaxException
            If a binary literal's digit count is not a multiple of 8.
        CompilerPanic
            If the literal's value is neither a float nor an int.
        """
        self.generic_visit(node)
        source_text = node.node_source_code
        prefix = source_text.lower()[:2]

        # hex / octal literals keep their source text as the value
        radix_names = {"0x": "Hex", "0o": "Octal"}
        if prefix in radix_names:
            node.ast_type = radix_names[prefix]
            node.n = source_text
            return node

        if prefix == "0b":
            node.ast_type = "Bytes"
            bit_count = len(source_text) - 2
            remainder = bit_count % 8
            if remainder:
                raise SyntaxException(
                    f"Bit notation requires a multiple of 8 bits. {8-remainder} bit(s) are missing.",
                    self._source_code,
                    node.lineno,
                    node.col_offset,
                )
            node.value = int(source_text, 2).to_bytes(bit_count // 8, "big")
            return node

        literal = node.n
        if isinstance(literal, float):
            # re-parse from source to avoid binary floating-point rounding
            node.ast_type = "Decimal"
            node.n = Decimal(source_text)
        elif isinstance(literal, int):
            node.ast_type = "Int"
        else:
            raise CompilerPanic(
                f"Unexpected type for Constant value: {type(literal).__name__}")
        return node
Пример #12
0
    def __init__(self,
                 value: Union[str, int],
                 args: List['LLLnode'] = None,
                 typ: 'BaseType' = None,
                 location: str = None,
                 pos: Optional[Tuple[int, int]] = None,
                 annotation: Optional[str] = None,
                 mutable: bool = True,
                 add_gas_estimate: int = 0,
                 valency: Optional[int] = None):
        """
        Create an LLL node, validate its children and estimate its gas cost.

        Parameters
        ----------
        value : str | int | None
            An integer literal, an opcode / pseudo-opcode name, an LLL keyword
            (``if``, ``with``, ``repeat``, ``seq``, ``multi``, ``lll``, ...),
            any other string (treated as a stack variable), or ``None``.
        args : list of LLLnode, optional
            Child nodes; defaults to a new empty list.
        typ : NodeType, optional
            Type of the node's value; must be a NodeType or None.
        location : str, optional
            Location of the value (e.g. ``'memory'``); stored as given.
        pos : tuple of int, optional
            Source position; stored as given.
        annotation : str, optional
            Free-form annotation; stored as given.
        mutable : bool
            Stored as given.
        add_gas_estimate : int
            Extra gas added to the computed estimate.
        valency : int, optional
            When given, overrides the computed valency (1 if the node pushes
            a value on the stack, 0 otherwise).

        Raises
        ------
        CompilerPanic
            For structurally invalid nodes: wrong argument count for an
            opcode, zerovalent arguments where a value is required, an
            invalid repeat count, or an unsupported ``value`` type.
        """
        if args is None:
            args = []

        self.value = value
        self.args = args
        self.typ = typ
        # NOTE: `assert` is stripped under `python -O`, so this type check is
        # a debugging aid rather than a guaranteed validation.
        assert isinstance(self.typ, NodeType) or self.typ is None, repr(
            self.typ)
        self.location = location
        self.pos = pos
        self.annotation = annotation
        self.mutable = mutable
        self.add_gas_estimate = add_gas_estimate
        self.as_hex = AS_HEX_DEFAULT

        # Optional annotation properties for gas estimation
        self.total_gas = None
        self.func_name = None

        # Determine this node's valency (1 if it pushes a value on the stack,
        # 0 otherwise) and checks to make sure the number and valencies of
        # children are correct. Also, find an upper bound on gas consumption
        # Numbers
        if isinstance(self.value, int):
            self.valency = 1
            self.gas = 5
        elif isinstance(self.value, str):
            # Opcodes and pseudo-opcodes (e.g. clamp)
            if self.value.upper() in get_comb_opcodes():
                # table entry: (_, required arg count, valency, base gas)
                _, ins, outs, gas = get_comb_opcodes()[self.value.upper()]
                self.valency = outs
                if len(self.args) != ins:
                    raise CompilerPanic(
                        f"Number of arguments mismatched: {self.value} {self.args}"
                    )
                # We add 2 per stack height at push time and take it back
                # at pop time; this makes `break` easier to handle
                self.gas = gas + 2 * (outs - ins)
                for arg in self.args:
                    # pop and pass are used to push/pop values on the stack to be
                    # consumed for private functions, therefore we whitelist this as a zero valency
                    # allowed argument.
                    zero_valency_whitelist = {'pass', 'pop'}
                    if arg.valency == 0 and arg.value not in zero_valency_whitelist:
                        raise CompilerPanic(
                            "Can't have a zerovalent argument to an opcode or a pseudo-opcode! "
                            f"{arg.value}: {arg}. Please file a bug report.")
                    self.gas += arg.gas
                # Dynamic gas cost: 8 gas for each byte of logging data
                # (only when args[1] is a compile-time constant)
                if self.value.upper()[0:3] == 'LOG' and isinstance(
                        self.args[1].value, int):
                    self.gas += self.args[1].value * 8
                # Dynamic gas cost: non-zero-valued call
                if self.value.upper() == 'CALL' and self.args[2].value != 0:
                    self.gas += 34000
                # Dynamic gas cost: filling sstore (ie. not clearing)
                elif self.value.upper(
                ) == 'SSTORE' and self.args[1].value != 0:
                    self.gas += 15000
                # Dynamic gas cost: calldatacopy
                # (a non-constant size is charged as if it were 34000)
                elif self.value.upper() in ('CALLDATACOPY', 'CODECOPY'):
                    size = 34000
                    if isinstance(self.args[2].value, int):
                        size = self.args[2].value
                    self.gas += ceil32(size) // 32 * 3
                # Gas limits in call
                if self.value.upper() == 'CALL' and isinstance(
                        self.args[0].value, int):
                    self.gas += self.args[0].value
            # If statements
            # NOTE(review): args[0] is indexed before the argument-count check
            # below, so an `if` with no arguments raises IndexError rather than
            # CompilerPanic.
            elif self.value == 'if':
                # 3 args: condition plus the more expensive branch
                if len(self.args) == 3:
                    self.gas = self.args[0].gas + max(self.args[1].gas,
                                                      self.args[2].gas) + 3
                if len(self.args) == 2:
                    self.gas = self.args[0].gas + self.args[1].gas + 17
                if not self.args[0].valency:
                    raise CompilerPanic(
                        "Can't have a zerovalent argument as a test to an if "
                        f"statement! {self.args[0]}")
                if len(self.args) not in (2, 3):
                    raise CompilerPanic("If can only have 2 or 3 arguments")
                # valency follows the then-branch; the else-branch's valency
                # is not checked against it here
                self.valency = self.args[1].valency
            # With statements: with <var> <initial> <statement>
            elif self.value == 'with':
                if len(self.args) != 3:
                    raise CompilerPanic("With statement must have 3 arguments")
                if len(self.args[0].args) or not isinstance(
                        self.args[0].value, str):
                    raise CompilerPanic(
                        "First argument to with statement must be a variable")
                if not self.args[1].valency:
                    raise CompilerPanic(
                        ("Second argument to with statement (initial value) "
                         f"cannot be zerovalent: {self.args[1]}"))
                self.valency = self.args[2].valency
                self.gas = sum([arg.gas for arg in self.args]) + 5
            # Repeat statements: repeat <index_memloc> <startval> <rounds> <body>
            # NOTE(review): the number of args is never checked, so a repeat
            # with fewer than 4 args raises IndexError.
            elif self.value == 'repeat':
                # the round count must be a literal int with no children, > 0
                is_invalid_repeat_count = any((
                    len(self.args[2].args),
                    not isinstance(self.args[2].value, int),
                    isinstance(self.args[2].value, int)
                    and self.args[2].value <= 0,
                ))

                if is_invalid_repeat_count:
                    raise CompilerPanic(
                        ("Number of times repeated must be a constant nonzero "
                         f"positive integer: {self.args[2]}"))
                if not self.args[0].valency:
                    raise CompilerPanic((
                        "First argument to repeat (memory location) cannot be "
                        f"zerovalent: {self.args[0]}"))
                if not self.args[1].valency:
                    raise CompilerPanic(
                        ("Second argument to repeat (start value) cannot be "
                         f"zerovalent: {self.args[1]}"))
                if self.args[3].valency:
                    raise CompilerPanic((
                        "Third argument to repeat (clause to be repeated) must "
                        f"be zerovalent: {self.args[3]}"))
                self.valency = 0
                rounds: int
                if self.args[1].value in ('calldataload', 'mload'
                                          ) or self.args[1].value == 'sload':
                    if isinstance(self.args[2].value, int):
                        rounds = self.args[2].value
                    else:
                        raise CompilerPanic(
                            f'Unsupported rounds argument type. {self.args[2]}'
                        )
                else:
                    # NOTE(review): args[2] is already the round count (see
                    # the header comment above), so `abs(rounds - start)`
                    # looks suspicious as an iteration count for the gas
                    # estimate — confirm intent. A non-literal, non-load start
                    # value (e.g. a stack variable) lands in the panic below.
                    if isinstance(self.args[2].value, int) and isinstance(
                            self.args[1].value, int):
                        rounds = abs(self.args[2].value - self.args[1].value)
                    else:
                        raise CompilerPanic(
                            f'Unsupported second argument types. {self.args}')
                self.gas = rounds * (self.args[3].gas + 50) + 30
            # Seq statements: seq <statement> <statement> ...
            elif self.value == 'seq':
                # valency is that of the last statement (0 for an empty seq)
                self.valency = self.args[-1].valency if self.args else 0
                self.gas = sum([arg.gas for arg in self.args]) + 30
            # Multi statements: multi <expr> <expr> ...
            elif self.value == 'multi':
                for arg in self.args:
                    if not arg.valency:
                        raise CompilerPanic(
                            f"Multi expects all children to not be zerovalent: {arg}"
                        )
                self.valency = sum([arg.valency for arg in self.args])
                self.gas = sum([arg.gas for arg in self.args])
            # LLL brackets (don't bother gas counting)
            elif self.value == 'lll':
                self.valency = 1
                self.gas = NullAttractor()
            # Stack variables
            # (any other string, including the `seq_unchecked` and
            # `if_unchecked` pseudo-forms, which get valency 1 here and only
            # override the flat gas estimate)
            else:
                self.valency = 1
                self.gas = 5
                if self.value == 'seq_unchecked':
                    self.gas = sum([arg.gas for arg in self.args]) + 30
                if self.value == 'if_unchecked':
                    self.gas = self.args[0].gas + self.args[1].gas + 17
        elif self.value is None:
            self.valency = 1
            # None LLLnodes always get compiled into something else, e.g.
            # mzero or PUSH1 0, and the gas will get re-estimated then.
            self.gas = 3
        else:
            raise CompilerPanic(
                f"Invalid value for LLL AST node: {self.value}")
        assert isinstance(self.args, list)

        # an explicit valency argument overrides the computed one
        if valency is not None:
            self.valency = valency

        self.gas += self.add_gas_estimate
Пример #13
0
def abi_encode(dst, lll_node, pos=None, bufsz=None, returns=False):
    """
    Build LLL that ABI-encodes ``lll_node`` into memory starting at ``dst``.

    Tuple types are encoded member by member:
    - Static members are written in place in the static section (head).
    - Dynamic members get an offset word in the head, and their data is
      encoded at ``dst + dyn_ofst`` in the dynamic section (tail).

    Base types and byte arrays are written with ``make_setter``; byte arrays
    are additionally zero-padded.

    Parameters
    ----------
    dst :
        LLL expression for the start of the destination buffer.
    lll_node : LLLnode
        The value to encode; its ``typ`` determines the ABI layout.
    pos : optional
        Source position, forwarded to the generated setters.
    bufsz : int, optional
        Size of the destination buffer. If given and smaller than
        ``32 * (static_size() + dynamic_size_bound())`` of the ABI type, a
        CompilerPanic is raised. NOTE(review): the factor 32 suggests bufsz is
        expected in 32-byte words (or that the check is conservative) —
        confirm against callers.
    returns : bool
        If True, the generated code leaves the encoded size on the stack
        (a caller encoding a dynamic tuple member uses it to advance its
        dynamic-section offset).

    Returns
    -------
    LLLnode

    Raises
    ------
    CompilerPanic
        If the buffer is too small or a type cannot be encoded.
    """
    parent_abi_t = abi_type_of(lll_node.typ)
    size_bound = parent_abi_t.static_size() + parent_abi_t.dynamic_size_bound()
    if bufsz is not None and bufsz < 32 * size_bound:
        raise CompilerPanic('buffer provided to abi_encode not large enough')

    lll_ret = ['seq']
    # names of the LLL `with` variables bound around the generated code
    dyn_ofst = 'dyn_ofst'  # current offset in the dynamic section
    dst_begin = 'dst'  # pointer to beginning of buffer
    dst_loc = 'dst_loc'  # pointer to write location in static section
    os = o_list(lll_node, pos=pos)

    for i, o in enumerate(os):
        abi_t = abi_type_of(o.typ)

        if parent_abi_t.is_tuple():
            if abi_t.is_dynamic():
                # head slot gets the offset of this member's data
                lll_ret.append(['mstore', dst_loc, dyn_ofst])
                # recurse
                child_dst = ['add', dst_begin, dyn_ofst]
                child = abi_encode(child_dst, o, pos=pos, returns=True)
                # increment dyn ofst for the return
                # (optimization note:
                #   if non-returning and this is the last dyn member in
                #   the tuple, this set can be elided.)
                lll_ret.append(['set', dyn_ofst, ['add', dyn_ofst, child]])
            else:
                # recurse
                lll_ret.append(abi_encode(dst_loc, o, pos=pos, returns=False))

        elif isinstance(o.typ, BaseType):
            d = LLLnode(dst_loc, typ=o.typ, location='memory')
            lll_ret.append(make_setter(d, o, location=d.location, pos=pos))
        elif isinstance(o.typ, ByteArrayLike):
            d = LLLnode.from_list(dst_loc, typ=o.typ, location='memory')
            lll_ret.append([
                'seq',
                make_setter(d, o, location=d.location, pos=pos),
                zero_pad(d)
            ])
        else:
            raise CompilerPanic(f'unreachable type: {o.typ}')

        if i + 1 == len(os):
            pass  # optimize out the last increment to dst_loc
        else:  # note: always false for non-tuple types
            sz = abi_t.embedded_static_size()
            lll_ret.append(['set', dst_loc, ['add', dst_loc, sz]])

    # leave the encoded size on the stack if requested
    if returns:
        if not parent_abi_t.is_dynamic():
            lll_ret.append(parent_abi_t.embedded_static_size())
        elif parent_abi_t.is_tuple():
            lll_ret.append(dyn_ofst)
        elif isinstance(lll_node.typ, ByteArrayLike):
            # for abi purposes, return zero-padded length
            calc_len = ['ceil32', ['add', 32, ['mload', dst_loc]]]
            lll_ret.append(calc_len)
        else:
            raise CompilerPanic(f'unknown type {lll_node.typ}')

    # declare LLL variables.
    # Only dynamic tuples have a dynamic section, so only they need dyn_ofst;
    # it starts right after the static section.
    if parent_abi_t.is_dynamic() and parent_abi_t.is_tuple():
        dyn_section_start = parent_abi_t.static_size()
        lll_ret = ['with', dyn_ofst, dyn_section_start, lll_ret]

    lll_ret = ['with', dst_begin, dst, ['with', dst_loc, dst_begin, lll_ret]]

    return LLLnode.from_list(lll_ret)
Пример #14
0
    def __init__(self, bytes_bound):
        """Store the maximum byte length; CompilerPanic if it is negative."""
        bound_is_valid = bytes_bound >= 0
        if not bound_is_valid:
            raise CompilerPanic('Negative bytes_bound provided to ABI_Bytes')

        self.bytes_bound = bytes_bound
Пример #15
0
    def __init__(self, subtyp, elems_bound):
        """
        Store the element type and the maximum number of elements.

        Raises CompilerPanic if ``elems_bound`` is negative.
        """
        bound_is_valid = elems_bound >= 0
        if not bound_is_valid:
            raise CompilerPanic('Negative bound provided to DynamicArray')

        self.subtyp, self.elems_bound = subtyp, elems_bound
Пример #16
0
    def __init__(self, subtyp, m_elems):
        """
        Store the element type and the fixed element count ``m_elems``.

        Raises CompilerPanic if ``m_elems`` is negative.
        """
        count_is_valid = m_elems >= 0
        if not count_is_valid:
            raise CompilerPanic('Invalid M')

        self.subtyp, self.m_elems = subtyp, m_elems
Пример #17
0
    def __init__(self, m_bytes):
        """Store the byte width; CompilerPanic unless 1 <= m_bytes <= 32."""
        width_is_valid = 0 < m_bytes <= 32
        if not width_is_valid:
            raise CompilerPanic('Invalid M for BytesM')

        self.m_bytes = m_bytes
Пример #18
0
def compile_to_assembly(code,
                        withargs=None,
                        existing_labels=None,
                        break_dest=None,
                        height=0):
    if withargs is None:
        withargs = {}
    if not isinstance(withargs, dict):
        raise CompilerPanic(f"Incorrect type for withargs: {type(withargs)}")

    if existing_labels is None:
        existing_labels = set()
    if not isinstance(existing_labels, set):
        raise CompilerPanic(
            f"Incorrect type for existing_labels: {type(existing_labels)}")

    # Opcodes
    if isinstance(code.value, str) and code.value.upper() in get_opcodes():
        o = []
        for i, c in enumerate(code.args[::-1]):
            o.extend(
                compile_to_assembly(c, withargs, existing_labels, break_dest,
                                    height + i))
        o.append(code.value.upper())
        return o
    # Numbers
    elif isinstance(code.value, int):
        if code.value <= -2**255:
            raise Exception(f"Value too low: {code.value}")
        elif code.value >= 2**256:
            raise Exception(f"Value too high: {code.value}")
        bytez = num_to_bytearray(code.value % 2**256) or [0]
        return ['PUSH' + str(len(bytez))] + bytez
    # Variables connected to with statements
    elif isinstance(code.value, str) and code.value in withargs:
        if height - withargs[code.value] > 16:
            raise Exception("With statement too deep")
        return ['DUP' + str(height - withargs[code.value])]
    # Setting variables connected to with statements
    elif code.value == "set":
        if len(code.args) != 2 or code.args[0].value not in withargs:
            raise Exception(
                "Set expects two arguments, the first being a stack variable")
        if height - withargs[code.args[0].value] > 16:
            raise Exception("With statement too deep")
        return compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height) + \
            ['SWAP' + str(height - withargs[code.args[0].value]), 'POP']
    # Pass statements
    elif code.value == 'pass':
        return []
    # Code length
    elif code.value == '~codelen':
        return ['_sym_codeend']
    # Calldataload equivalent for code
    elif code.value == 'codeload':
        return compile_to_assembly(
            LLLnode.from_list([
                'seq',
                ['codecopy', MemoryPositions.FREE_VAR_SPACE, code.args[0], 32],
                ['mload', MemoryPositions.FREE_VAR_SPACE]
            ]), withargs, existing_labels, break_dest, height)
    # If statements (2 arguments, ie. if x: y)
    elif code.value in ('if', 'if_unchecked') and len(code.args) == 2:
        o = []
        o.extend(
            compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height))
        end_symbol = mksymbol()
        o.extend(['ISZERO', end_symbol, 'JUMPI'])
        o.extend(
            compile_to_assembly(code.args[1], withargs, existing_labels,
                                break_dest, height))
        o.extend([end_symbol, 'JUMPDEST'])
        return o
    # If statements (3 arguments, ie. if x: y, else: z)
    elif code.value == 'if' and len(code.args) == 3:
        o = []
        o.extend(
            compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height))
        mid_symbol = mksymbol()
        end_symbol = mksymbol()
        o.extend(['ISZERO', mid_symbol, 'JUMPI'])
        o.extend(
            compile_to_assembly(code.args[1], withargs, existing_labels,
                                break_dest, height))
        o.extend([end_symbol, 'JUMP', mid_symbol, 'JUMPDEST'])
        o.extend(
            compile_to_assembly(code.args[2], withargs, existing_labels,
                                break_dest, height))
        o.extend([end_symbol, 'JUMPDEST'])
        return o
    # Repeat statements (compiled from for loops)
    # Repeat(memloc, start, rounds, body)
    elif code.value == 'repeat':
        o = []
        loops = num_to_bytearray(code.args[2].value)
        start, continue_dest, end = mksymbol(), mksymbol(), mksymbol()
        o.extend(
            compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height))
        o.extend(
            compile_to_assembly(
                code.args[1],
                withargs,
                existing_labels,
                break_dest,
                height + 1,
            ))
        o.extend(['PUSH' + str(len(loops))] + loops)
        # stack: memloc, startvalue, rounds
        o.extend(['DUP2', 'DUP4', 'MSTORE', 'ADD', start, 'JUMPDEST'])
        # stack: memloc, exit_index
        o.extend(
            compile_to_assembly(
                code.args[3],
                withargs,
                existing_labels,
                (end, continue_dest, height + 2),
                height + 2,
            ))
        # stack: memloc, exit_index
        o.extend([
            continue_dest,
            'JUMPDEST',
            'DUP2',
            'MLOAD',
            'PUSH1',
            1,
            'ADD',
            'DUP1',
            'DUP4',
            'MSTORE',
        ])
        # stack: len(loops), index memory address, new index
        o.extend([
            'DUP2', 'EQ', 'ISZERO', start, 'JUMPI', end, 'JUMPDEST', 'POP',
            'POP'
        ])
        return o
    # Continue to the next iteration of the for loop
    elif code.value == 'continue':
        if not break_dest:
            raise Exception("Invalid break")
        dest, continue_dest, break_height = break_dest
        return [continue_dest, 'JUMP']
    # Break from inside a for loop
    elif code.value == 'break':
        if not break_dest:
            raise Exception("Invalid break")
        dest, continue_dest, break_height = break_dest
        return ['POP'] * (height - break_height) + [dest, 'JUMP']
    # With statements
    elif code.value == 'with':
        o = []
        o.extend(
            compile_to_assembly(code.args[1], withargs, existing_labels,
                                break_dest, height))
        old = withargs.get(code.args[0].value, None)
        withargs[code.args[0].value] = height
        o.extend(
            compile_to_assembly(
                code.args[2],
                withargs,
                existing_labels,
                break_dest,
                height + 1,
            ))
        if code.args[2].valency:
            o.extend(['SWAP1', 'POP'])
        else:
            o.extend(['POP'])
        if old is not None:
            withargs[code.args[0].value] = old
        else:
            del withargs[code.args[0].value]
        return o
    # LLL statement (used to contain code inside code)
    elif code.value == 'lll':
        o = []
        begincode = mksymbol()
        endcode = mksymbol()
        o.extend([endcode, 'JUMP', begincode, 'BLANK'])
        # The `append(...)` call here is intentional
        o.append(
            compile_to_assembly(code.args[0], {}, existing_labels, None, 0))
        o.extend([endcode, 'JUMPDEST', begincode, endcode, 'SUB', begincode])
        o.extend(
            compile_to_assembly(code.args[1], withargs, existing_labels,
                                break_dest, height))
        o.extend(['CODECOPY', begincode, endcode, 'SUB'])
        return o
    # Seq (used to piece together multiple statements)
    elif code.value == 'seq':
        o = []
        for arg in code.args:
            o.extend(
                compile_to_assembly(arg, withargs, existing_labels, break_dest,
                                    height))
            if arg.valency == 1 and arg != code.args[-1]:
                o.append('POP')
        return o
    # Seq without popping.
    elif code.value == 'seq_unchecked':
        o = []
        for arg in code.args:
            o.extend(
                compile_to_assembly(arg, withargs, existing_labels, break_dest,
                                    height))
            # if arg.valency == 1 and arg != code.args[-1]:
            #     o.append('POP')
        return o
    # Assure (if false, invalid opcode)
    elif code.value == 'assert_unreachable':
        o = compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height)
        end_symbol = mksymbol()
        o.extend([end_symbol, 'JUMPI', 'INVALID', end_symbol, 'JUMPDEST'])
        return o
    # Assert (if false, exit)
    elif code.value == 'assert':
        o = compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height)
        o.extend(get_revert())
        return o
    elif code.value == 'assert_reason':
        o = compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height)
        mem_start = compile_to_assembly(code.args[1], withargs,
                                        existing_labels, break_dest, height)
        mem_len = compile_to_assembly(code.args[2], withargs, existing_labels,
                                      break_dest, height)
        o.extend(get_revert(mem_start, mem_len))
        return o
    # Unsigned/signed clamp, check less-than
    elif code.value in CLAMP_OP_NAMES:
        if isinstance(code.args[0].value, int) and isinstance(
                code.args[1].value, int):
            # Checks for clamp errors at compile time as opposed to run time
            args_0_val = code.args[0].value
            args_1_val = code.args[1].value
            is_free_of_clamp_errors = any((
                code.value in ('uclamplt', 'clamplt')
                and 0 <= args_0_val < args_1_val,
                code.value in ('uclample', 'clample')
                and 0 <= args_0_val <= args_1_val,
                code.value in ('uclampgt', 'clampgt')
                and 0 <= args_0_val > args_1_val,
                code.value in ('uclampge', 'clampge')
                and 0 <= args_0_val >= args_1_val,
            ))
            if is_free_of_clamp_errors:
                return compile_to_assembly(
                    code.args[0],
                    withargs,
                    existing_labels,
                    break_dest,
                    height,
                )
            else:
                raise Exception(
                    f"Invalid {code.value} with values {code.args[0]} and {code.args[1]}"
                )
        o = compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height)
        o.extend(
            compile_to_assembly(
                code.args[1],
                withargs,
                existing_labels,
                break_dest,
                height + 1,
            ))
        o.extend(['DUP2'])
        # Stack: num num bound
        if code.value == 'uclamplt':
            o.extend(['LT'])
        elif code.value == "clamplt":
            o.extend(['SLT'])
        elif code.value == "uclample":
            o.extend(['GT', 'ISZERO'])
        elif code.value == "clample":
            o.extend(['SGT', 'ISZERO'])
        elif code.value == 'uclampgt':
            o.extend(['GT'])
        elif code.value == "clampgt":
            o.extend(['SGT'])
        elif code.value == "uclampge":
            o.extend(['LT', 'ISZERO'])
        elif code.value == "clampge":
            o.extend(['SLT', 'ISZERO'])
        o.extend(get_revert())
        return o
    # Signed clamp, check against upper and lower bounds
    elif code.value in ('clamp', 'uclamp'):
        comp1 = 'SGT' if code.value == 'clamp' else 'GT'
        comp2 = 'SLT' if code.value == 'clamp' else 'LT'
        o = compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height)
        o.extend(
            compile_to_assembly(
                code.args[1],
                withargs,
                existing_labels,
                break_dest,
                height + 1,
            ))
        o.extend(['DUP1'])
        o.extend(
            compile_to_assembly(
                code.args[2],
                withargs,
                existing_labels,
                break_dest,
                height + 3,
            ))
        o.extend(['SWAP1', comp1, 'ISZERO'])
        o.extend(get_revert())
        o.extend(['DUP1', 'SWAP2', 'SWAP1', comp2, 'ISZERO'])
        o.extend(get_revert())
        return o
    # Checks that a value is nonzero
    elif code.value == 'clamp_nonzero':
        o = compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height)
        o.extend(['DUP1'])
        o.extend(get_revert())
        return o
    # SHA3 a single value
    elif code.value == 'sha3_32':
        o = compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height)
        o.extend([
            'PUSH1', MemoryPositions.FREE_VAR_SPACE, 'MSTORE', 'PUSH1', 32,
            'PUSH1', MemoryPositions.FREE_VAR_SPACE, 'SHA3'
        ])
        return o
    # SHA3 a 64 byte value
    elif code.value == 'sha3_64':
        o = compile_to_assembly(code.args[0], withargs, existing_labels,
                                break_dest, height)
        o.extend(
            compile_to_assembly(code.args[1], withargs, existing_labels,
                                break_dest, height))
        o.extend([
            'PUSH1', MemoryPositions.FREE_VAR_SPACE2, 'MSTORE', 'PUSH1',
            MemoryPositions.FREE_VAR_SPACE, 'MSTORE', 'PUSH1', 64, 'PUSH1',
            MemoryPositions.FREE_VAR_SPACE, 'SHA3'
        ])
        return o
    # <= operator
    elif code.value == 'le':
        return compile_to_assembly(
            LLLnode.from_list([
                'iszero',
                ['gt', code.args[0], code.args[1]],
            ]), withargs, existing_labels, break_dest, height)
    # >= operator
    elif code.value == 'ge':
        return compile_to_assembly(
            LLLnode.from_list([
                'iszero',
                ['lt', code.args[0], code.args[1]],
            ]), withargs, existing_labels, break_dest, height)
    # <= operator
    elif code.value == 'sle':
        return compile_to_assembly(
            LLLnode.from_list([
                'iszero',
                ['sgt', code.args[0], code.args[1]],
            ]), withargs, existing_labels, break_dest, height)
    # >= operator
    elif code.value == 'sge':
        return compile_to_assembly(
            LLLnode.from_list([
                'iszero',
                ['slt', code.args[0], code.args[1]],
            ]), withargs, existing_labels, break_dest, height)
    # != operator
    elif code.value == 'ne':
        return compile_to_assembly(
            LLLnode.from_list([
                'iszero',
                ['eq', code.args[0], code.args[1]],
            ]), withargs, existing_labels, break_dest, height)
    # e.g. 95 -> 96, 96 -> 96, 97 -> 128
    elif code.value == "ceil32":
        return compile_to_assembly(
            LLLnode.from_list([
                'with', '_val', code.args[0],
                [
                    'sub',
                    ['add', '_val', 31],
                    ['mod', ['sub', '_val', 1], 32],
                ]
            ]), withargs, existing_labels, break_dest, height)
    # # jump to a symbol
    elif code.value == 'goto':
        return ['_sym_' + str(code.args[0]), 'JUMP']
    elif isinstance(code.value, str) and code.value.startswith('_sym_'):
        return code.value
    # set a symbol as a location.
    elif code.value == 'label':
        label_name = str(code.args[0])

        if label_name in existing_labels:
            raise Exception(f'Label with name {label_name} already exists!')
        else:
            existing_labels.add(label_name)

        return ['_sym_' + label_name, 'JUMPDEST']
    # inject debug opcode.
    elif code.value == 'debugger':
        return mkdebug(pc_debugger=False, pos=code.pos)
    # inject debug opcode.
    elif code.value == 'pc_debugger':
        return mkdebug(pc_debugger=True, pos=code.pos)
    else:
        raise Exception("Weird code element: " + repr(code))
Пример #19
0
def make_byte_array_copier(destination, source, pos=None):
    """
    Build LLL that copies a byte array / string from `source` to `destination`.

    Parameters
    ----------
    destination : LLLnode
        Target of the copy. Its ``typ.maxlen`` bounds what may be copied in;
        a storage destination is addressed through ``sha3_32``.
    source : LLLnode
        Value to copy. A ``value`` of None denotes a zero/empty value, in which
        case only the first 32-byte word of the destination is cleared.
    pos : optional
        Source position attached to raised errors and to the generated node.

    Returns
    -------
    LLLnode
        A node with ``typ=None`` performing the copy.

    Raises
    ------
    TypeMismatch
        If the source is not byte-array-like, has a larger max-length than the
        destination, or is a zero value whose max-length differs from the
        destination's.
    CompilerPanic
        If a non-zero source lives somewhere other than memory or storage.
    """
    if not isinstance(source.typ, ByteArrayLike):
        btype = 'byte array' if isinstance(destination.typ,
                                           ByteArrayType) else 'string'
        raise TypeMismatch(f"Can only set a {btype} to another {btype}", pos)
    # From here on `source.typ` is known to be ByteArrayLike.
    if source.typ.maxlen > destination.typ.maxlen:
        raise TypeMismatch(
            f"Cannot cast from greater max-length {source.typ.maxlen} to shorter "
            f"max-length {destination.typ.maxlen}", pos)

    is_zero = source.value is None

    # Stricter check for zeroing a byte array: the empty value must have
    # exactly the destination's max-length.
    if is_zero and source.typ.maxlen != destination.typ.maxlen:
        raise TypeMismatch(
            f"Bad type for clearing bytes: expected {destination.typ}"
            f" but got {source.typ}", pos)

    # Special case: memory to memory. The length word plus the data
    # (`mload(_source) + 32` bytes) are copied with one call to the identity
    # precompile (address 4); `assert` fails the copy if that call fails.
    if source.location == "memory" and destination.location == "memory":
        gas_calculation = GAS_IDENTITY + GAS_IDENTITYWORD * (
            ceil32(source.typ.maxlen) // 32)
        return LLLnode.from_list(
            [
                'with', '_source', source,
                [
                    'with', '_sz', ['add', 32, ['mload', '_source']],
                    [
                        'assert',
                        [
                            'call', ['gas'], 4, 0, '_source', '_sz',
                            destination, '_sz'
                        ]
                    ]
                ]
            ],  # noqa: E501
            typ=None,
            add_gas_estimate=gas_calculation,
            annotation='Memory copy',
            pos=pos)

    if is_zero:
        # Clearing: the slice copier writes a single 32-byte word (see
        # `max_length` below); any non-zero length lets that word be written.
        pos_node = source
        length = 1
    else:
        # The source expression is bound to `_pos` by the `with` below.
        pos_node = LLLnode.from_list('_pos',
                                     typ=source.typ,
                                     location=source.location)
        if source.location == "memory":
            # The first word holds the data length; +32 covers that word.
            length = ['add', ['mload', '_pos'], 32]
        elif source.location == "storage":
            # The length is read from the base slot, while the copy itself
            # reads from the slot derived via `sha3_32`.
            length = ['add', ['sload', '_pos'], 32]
            pos_node = LLLnode.from_list(
                ['sha3_32', pos_node],
                typ=source.typ,
                location=source.location,
            )
        else:
            raise CompilerPanic(f"Unsupported location: {source.location}")

    if destination.location == "storage":
        destination = LLLnode.from_list(
            ['sha3_32', destination],
            typ=destination.typ,
            location=destination.location,
        )
    # Maximum theoretical length (data plus the 32-byte length word).
    max_length = 32 if is_zero else source.typ.maxlen + 32
    return LLLnode.from_list([
        'with', '_pos', 0 if is_zero else source,
        make_byte_slice_copier(
            destination, pos_node, length, max_length, pos=pos)
    ],
                             typ=None)
Пример #20
0
    def parse_for_list(self):
        """
        Compile a ``for <target> in <list>`` statement into LLL.

        Supported iteration targets:

        * a ``Name`` bound to a variable whose location is memory or calldata;
        * a literal ``List`` in the ``for`` statement, first copied into a
          temporary memory buffer;
        * an ``Attribute`` referring to a list read with ``sload`` (storage).

        The loop variable and a hidden index variable are allocated for the
        loop and removed from the context once the loop has been compiled.

        Returns
        -------
        LLLnode
            A ``repeat`` loop (wrapped in ``seq`` for the literal and
            attribute cases).

        Raises
        ------
        StructureException
            If the list subtype is not a ``BaseType`` or the iteration target
            is not one of the supported forms.
        CompilerPanic
            If a named list lives in a location other than memory or calldata.
        """
        with self.context.range_scope():
            iter_list_node = Expr(self.stmt.iter, self.context).lll_node
        # Sanity check on list subtype: only lists of base types are iterable.
        if not isinstance(iter_list_node.typ.subtype, BaseType):
            raise StructureException(
                'For loops allowed only on basetype lists.', self.stmt.iter)
        iter_var_type = (self.context.vars.get(self.stmt.iter.id).typ
                         if isinstance(self.stmt.iter, sri_ast.Name) else None)
        subtype = iter_list_node.typ.subtype.typ
        varname = self.stmt.target.id
        # Variable the loop body reads the current element from.
        value_pos = self.context.new_variable(
            varname,
            BaseType(subtype),
        )
        # Hidden variable used by `repeat` as the loop counter.
        i_pos_raw_name = '_index_for_' + varname
        i_pos = self.context.new_internal_variable(
            i_pos_raw_name,
            BaseType(subtype),
        )
        self.context.forvars[varname] = True

        if iter_var_type:
            # Named list that is already allocated (memory or calldata).
            list_name = self.stmt.iter.id
            # Make sure the list cannot be altered whilst iterating.
            with self.context.in_for_loop_scope(list_name):
                iter_var = self.context.vars.get(self.stmt.iter.id)
                if iter_var.location == 'calldata':
                    fetcher = 'calldataload'
                elif iter_var.location == 'memory':
                    fetcher = 'mload'
                else:
                    raise CompilerPanic(
                        'List iteration only supported on in-memory types, '
                        f'got location {iter_var.location}',
                    )
                # Element i lives at pos + 32 * i.
                body = [
                    'seq',
                    [
                        'mstore',
                        value_pos,
                        [
                            fetcher,
                            [
                                'add', iter_var.pos,
                                ['mul', ['mload', i_pos], 32]
                            ]
                        ],
                    ],
                    parse_body(self.stmt.body, self.context)
                ]
                o = LLLnode.from_list(
                    ['repeat', i_pos, 0, iter_var.size, body],
                    typ=None,
                    pos=getpos(self.stmt))

        elif isinstance(self.stmt.iter, sri_ast.List):
            # List literal defined in the for statement: copy it into a
            # temporary memory buffer, then iterate over that buffer.
            count = iter_list_node.typ.count
            tmp_list = LLLnode.from_list(obj=self.context.new_placeholder(
                ListType(iter_list_node.typ.subtype, count)),
                                         typ=ListType(
                                             iter_list_node.typ.subtype,
                                             count),
                                         location='memory')
            setter = make_setter(tmp_list,
                                 iter_list_node,
                                 'memory',
                                 pos=getpos(self.stmt))
            body = [
                'seq',
                [
                    'mstore', value_pos,
                    [
                        'mload',
                        ['add', tmp_list, ['mul', ['mload', i_pos], 32]]
                    ]
                ],
                parse_body(self.stmt.body, self.context)
            ]
            o = LLLnode.from_list(
                ['seq', setter, ['repeat', i_pos, 0, count, body]],
                typ=None,
                pos=getpos(self.stmt))

        elif isinstance(self.stmt.iter, sri_ast.Attribute):
            # List contained in storage: element i is at slot sha3(list) + i.
            count = iter_list_node.typ.count
            list_name = iter_list_node.annotation

            # Make sure the list cannot be altered whilst iterating.
            with self.context.in_for_loop_scope(list_name):
                body = [
                    'seq',
                    [
                        'mstore', value_pos,
                        [
                            'sload',
                            [
                                'add', ['sha3_32', iter_list_node],
                                ['mload', i_pos]
                            ]
                        ]
                    ],
                    parse_body(self.stmt.body, self.context),
                ]
                o = LLLnode.from_list(
                    ['seq', ['repeat', i_pos, 0, count, body]],
                    typ=None,
                    pos=getpos(self.stmt))

        else:
            # Without this branch `o` would be unbound at the return below.
            raise StructureException(
                'For loops allowed only on named, literal or storage lists.',
                self.stmt.iter)

        del self.context.vars[varname]
        # this kind of open access to the vars dict should be disallowed.
        # we should use member functions to provide an API for these kinds
        # of operations.
        del self.context.vars[self.context._mangle(i_pos_raw_name)]
        del self.context.forvars[varname]
        return o
Пример #21
0
    def __init__(self, m_bits, signed):
        """
        Describe a generic ABI integer type of width M (GIntM).

        Parameters
        ----------
        m_bits : int
            Bit width M; must satisfy 0 < M <= 256 and be a multiple of 8.
        signed : bool
            Whether the integer is signed.

        Raises
        ------
        CompilerPanic
            If ``m_bits`` is not a valid width.
        """
        width_is_valid = 0 < m_bits <= 256 and m_bits % 8 == 0
        if not width_is_valid:
            raise CompilerPanic('Invalid M provided for GIntM')

        self.signed = signed
        self.m_bits = m_bits
Пример #22
0
def make_byte_slice_copier(destination, source, length, max_length, pos=None):
    """
    Build LLL that copies up to `max_length` bytes from `source` to `destination`.

    Parameters
    ----------
    destination : LLLnode
        Where the bytes are written; ``location`` must be "memory" or
        "storage". It is bound to ``_opos`` in the generated loop.
    source : LLLnode
        Where the bytes are read from; bound to ``_pos`` in the generated
        loop. A ``value`` of None means "write zeros".
    length : LLL expression
        Runtime byte count; the word loop breaks once ``32 * i`` exceeds it.
        Not used by the memory-to-memory fast path.
    max_length : int
        Compile-time bound on the bytes copied; sets the loop's iteration
        count (``ceil(max_length / 32)`` words).
    pos : optional
        Source position attached to the generated node.

    Returns
    -------
    LLLnode
        A node with ``typ=None`` performing the copy.

    Raises
    ------
    CompilerPanic
        If the source or destination location is unsupported.
    """
    # Special case: memory to memory. One call to the identity precompile
    # (address 4) copies `max_length` bytes; the call's result is popped.
    if source.location == "memory" and destination.location == "memory":
        return LLLnode.from_list(
            [
                'with', '_l', max_length,
                [
                    'pop',
                    ['call', ['gas'], 4, 0, source, '_l', destination, '_l']
                ]
            ],
            typ=None,
            annotation=f'copy byte slice dest: {str(destination)}',
            pos=pos,
        )

    # How to read word i of the source.
    if source.value is None:
        # Special case: rhs is zero.
        if destination.location == 'memory':
            return mzero(destination, max_length)
        loader = 0
    elif source.location == "memory":
        # Memory is byte-addressed: word i is at _pos + 32 * i.
        loader = [
            'mload',
            [
                'add', '_pos',
                ['mul', 32, ['mload', MemoryPositions.FREE_LOOP_INDEX]]
            ]
        ]
    elif source.location == "storage":
        # Storage is slot-addressed: word i is at slot _pos + i.
        loader = [
            'sload',
            ['add', '_pos', ['mload', MemoryPositions.FREE_LOOP_INDEX]]
        ]
    else:
        raise CompilerPanic(f'Unsupported location: {source.location}')

    # Where to write word i of the destination.
    if destination.location == "memory":
        setter = [
            'mstore',
            [
                'add', '_opos',
                ['mul', 32, ['mload', MemoryPositions.FREE_LOOP_INDEX]]
            ], loader
        ]
    elif destination.location == "storage":
        setter = [
            'sstore',
            ['add', '_opos', ['mload', MemoryPositions.FREE_LOOP_INDEX]],
            loader
        ]
    else:
        raise CompilerPanic(f"Unsupported location: {destination.location}")

    # Stop once the byte offset of the current word exceeds the runtime
    # length. Note `gt` (not `ge`): callers may rely on the extra word this
    # copies when the length is an exact multiple of 32.
    checker = [
        'if',
        [
            'gt', ['mul', 32, ['mload', MemoryPositions.FREE_LOOP_INDEX]],
            '_actual_len'
        ], 'break'
    ]
    # Word-by-word copy loop, bounded by max_length.
    ipos = 0 if source.value is None else source
    o = [
        'with', '_pos', ipos,
        [
            'with', '_opos', destination,
            [
                'with', '_actual_len', length,
                [
                    'repeat', MemoryPositions.FREE_LOOP_INDEX, 0,
                    (max_length + 31) // 32, ['seq', checker, setter]
                ]
            ]
        ]
    ]
    return LLLnode.from_list(
        o,
        typ=None,
        annotation=f'copy byte slice src: {source} dst: {destination}',
        pos=pos,
    )