예제 #1
0
 def hexstring(self):
     """Compile a ``0x``-prefixed hex literal.

     A literal of 42 characters (including ``0x``) is an ``address`` and must
     already be in checksummed form; a literal of 66 characters (including
     ``0x``) is a ``bytes32``. Both become literal-typed integer nodes.

     Raises:
         InvalidLiteral: if an address literal fails its checksum, or the
             literal has any other length.
     """
     literal = self.expr.value
     width = len(literal)

     if width == 42:
         checksummed = checksum_encode(literal)
         if checksummed != literal:
             raise InvalidLiteral(
                 "Address checksum mismatch. If you are sure this is the "
                 f"right address, the correct checksummed form is: {checksummed}",
                 self.expr
             )
         literal_type = 'address'
     elif width == 66:
         literal_type = 'bytes32'
     else:
         raise InvalidLiteral(
             f"Cannot read 0x value with length {width}. Expecting 42 (address "
             "incl 0x) or 66 (bytes32 incl 0x)",
             self.expr
         )

     return LLLnode.from_list(
         int(literal, 16),
         typ=BaseType(literal_type, is_literal=True),
         pos=getpos(self.expr),
     )
예제 #2
0
    def _assert_reason(self, test_expr, msg):
        """Compile ``assert test_expr, msg`` into an ``assert_reason`` node.

        ``msg`` is either the name ``UNREACHABLE`` (delegated to
        ``_assert_unreachable``) or a non-blank string literal. For a string,
        memory receives the ``Error(string)`` selector, the ABI offset word
        (32) and the encoded string (these allocations are presumably
        contiguous -- TODO confirm). ``assert_reason`` is then given
        ``test_expr``, the start of that revert data and its size in bytes.

        Raises:
            StructureException: if ``msg`` is neither ``UNREACHABLE`` nor a
                string literal, or is empty/whitespace-only.
        """
        if isinstance(msg, sri_ast.Name) and msg.id == 'UNREACHABLE':
            return self._assert_unreachable(test_expr, msg)

        if not isinstance(msg, sri_ast.Str):
            raise StructureException(
                'Reason parameter of assert needs to be a literal string '
                '(or UNREACHABLE constant).', msg)
        if len(msg.s.strip()) == 0:
            raise StructureException('Empty reason string not allowed.',
                                     self.stmt)
        sig_placeholder = self.context.new_placeholder(BaseType(32))
        arg_placeholder = self.context.new_placeholder(BaseType(32))
        # Size the revert data by the string that is actually encoded below
        # (``msg`` as written, not stripped). A stripped length under-reports
        # the size when the literal has leading/trailing whitespace, which
        # cuts the reason data short.
        reason_str_type = ByteArrayType(len(msg.s))
        placeholder_bytes = Expr(msg, self.context).lll_node
        method_id = fourbytes_to_int(keccak256(b"Error(string)")[:4])
        assert_reason = [
            'seq',
            # The 4-byte selector lands in the last 4 bytes of its word.
            ['mstore', sig_placeholder, method_id],
            # ABI offset of the single string argument.
            ['mstore', arg_placeholder, 32],
            placeholder_bytes,
            [
                'assert_reason',
                test_expr,
                # Revert data starts at the selector: 28 bytes into its word.
                int(sig_placeholder + 28),
                # assumes get_size_of_type counts 32-byte words covering the
                # offset, length and data of the string -- TODO confirm.
                int(4 + get_size_of_type(reason_str_type) * 32),
            ],
        ]
        return LLLnode.from_list(assert_reason,
                                 typ=None,
                                 pos=getpos(self.stmt))
예제 #3
0
def test_mapping_node_types():
    # Passing plain Python types as key/value types must fail.
    with raises(Exception):
        MappingType(int, int)

    first_map = MappingType(BaseType('int128'), BaseType('int128'))
    second_map = MappingType(BaseType('int128'), BaseType('int128'))

    # Structurally identical mappings compare equal and render the same way.
    assert first_map == second_map
    assert str(first_map) == "map(int128, int128)"
예제 #4
0
def test_canonicalize_type():
    # A plain Python type is rejected.
    with raises(Exception):
        canonicalize_type(int)

    # A list whose element type is a byte array is rejected.
    list_of_byte_arrays = ListType(ByteArrayType(12), 2)
    with raises(Exception):
        canonicalize_type(list_of_byte_arrays)

    # A tuple renders as its ABI signature: parenthesised, comma-separated.
    int_address_pair = TupleType([BaseType('int128'), BaseType('address')])
    assert canonicalize_type(int_address_pair) == "(int128,address)"
예제 #5
0
def keccak256_helper(expr, args, kwargs, context):
    """Return an LLLnode (typed ``bytes32``) computing keccak256 of ``args[0]``.

    ``args[0]`` may be:

    * ``bytes``: a literal, hashed at compile time into a constant;
    * a ``bytes32`` value: written to ``MemoryPositions.FREE_VAR_SPACE`` and
      hashed as 32 bytes;
    * a byte array in memory: hashed in place (its first word is the length,
      the data follows 32 bytes later);
    * a byte array in storage: copied into a fresh memory placeholder, then
      hashed using the length read from storage.

    Raises:
        Exception: for a byte array in any other location.
    """
    subject = args[0]

    if isinstance(subject, bytes):
        digest = bytes_to_int(keccak256(subject))
        return LLLnode.from_list(digest,
                                 typ=BaseType('bytes32'),
                                 pos=getpos(expr))

    if is_base_type(subject.typ, 'bytes32'):
        code = [
            'seq',
            ['mstore', MemoryPositions.FREE_VAR_SPACE, subject],
            ['sha3', MemoryPositions.FREE_VAR_SPACE, 32],
        ]
    elif subject.location == "memory":
        # Already in memory: no copy needed, hash the bytes after the length word.
        code = [
            'with', '_sub', subject,
            ['sha3', ['add', '_sub', 32], ['mload', '_sub']],
        ]
    elif subject.location == "storage":
        stored_length = LLLnode.from_list(['sload', ['sha3_32', '_sub']],
                                          typ=BaseType('int128'))
        buffer = context.new_placeholder(subject.typ)
        copy_to_buffer = make_byte_array_copier(
            LLLnode.from_list(buffer, typ=subject.typ, location='memory'),
            LLLnode.from_list('_sub', typ=subject.typ, location=subject.location),
        )
        code = [
            'with',
            '_sub',
            subject,
            ['seq', copy_to_buffer, ['sha3', ['add', buffer, 32], stored_length]],
        ]
    else:
        # This should never happen, but just left here for future compiler-writers.
        raise Exception(
            f"Unsupported location: {subject.location}")  # pragma: no test

    return LLLnode.from_list(code,
                             typ=BaseType('bytes32'),
                             pos=getpos(expr))
예제 #6
0
 def integer(self):
     """Compile an integer literal, typing it ``int128`` or ``uint256``.

     A literal inside the int128 range, or any negative literal, is typed
     ``int128``. Every other (large positive) literal is typed ``uint256``.
     Negative literals below the int128 range are still typed ``int128``
     here (presumably rejected by a later range check -- TODO confirm).
     """
     value = self.expr.n
     as_int128 = SizeLimits.in_bounds('int128', value) or value < 0
     return LLLnode.from_list(
         value,
         typ=BaseType('int128' if as_int128 else 'uint256', is_literal=True),
         pos=getpos(self.expr),
     )
예제 #7
0
def to_address(expr, args, kwargs, context):
    """Reinterpret ``args[0]`` as an ``address`` without emitting conversion code.

    A new node is built from the argument's value and children, typed
    ``address`` and positioned at ``expr``.
    """
    source = args[0]
    return LLLnode(
        value=source.value,
        args=source.args,
        typ=BaseType('address'),
        pos=getpos(expr),
    )
예제 #8
0
    def variables(self):
        """Compile a bare name reference.

        The name is resolved in this order:
        1. ``self``, the contract's own address;
        2. a variable in the current context;
        3. a builtin constant;
        4. a user-defined constant.

        Raises:
            VariableDeclarationException: if none of these match.
        """
        name = self.expr.id

        if name == 'self':
            return LLLnode.from_list(['address'], typ='address', pos=getpos(self.expr))

        if name in self.context.vars:
            var = self.context.vars[name]
            return LLLnode.from_list(
                var.pos,
                typ=var.typ,
                # presumably 'memory' or 'calldata'; storage access appears
                # to be handled elsewhere -- TODO confirm.
                location=var.location,
                pos=getpos(self.expr),
                annotation=name,
                mutable=var.mutable,
            )

        if name in BUILTIN_CONSTANTS:
            obj, typ = BUILTIN_CONSTANTS[name]
            return LLLnode.from_list(
                [obj],
                typ=BaseType(typ, is_literal=True),
                pos=getpos(self.expr))

        if self.context.constants.ast_is_constant(self.expr):
            return self.context.constants.get_constant(name, self.context)

        raise VariableDeclarationException(f"Undeclared variable: {name}", self.expr)
예제 #9
0
 def constants(self):
     """Compile a ``True``/``False``/``None`` name constant.

     ``True`` and ``False`` become the ``bool`` literals 1 and 0.

     Raises:
         InvalidLiteral: for ``None``, which srilang does not support.
         Exception: for any other constant value.
     """
     value = self.expr.value
     if value is True or value is False:
         return LLLnode.from_list(
             1 if value else 0,
             typ=BaseType('bool', is_literal=True),
             pos=getpos(self.expr),
         )
     elif value is None:
         # block None
         raise InvalidLiteral(
             'None is not allowed in srilang '
             '(use a default value or built-in `empty()`)',
             self.expr,
         )
     else:
         # ``self.expr.value`` is the constant itself; formatting
         # ``.value.value`` raised AttributeError instead of this message.
         raise Exception(f"Unknown name constant: {value}")
예제 #10
0
def to_bool(expr, args, kwargs, context):
    """Convert ``args[0]`` to ``bool``: false exactly when the value is zero.

    A byte array is first decoded as a ``uint256``. This is only allowed
    when its max length is at most 32.

    Raises:
        TypeMismatch: if a byte array's max length exceeds 32.
    """
    value = args[0]
    input_type, _ = get_type(value)

    if input_type == 'bytes':
        if value.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {value.typ.maxlen} to bool",
                expr,
            )
        value = byte_array_to_num(value, expr, 'uint256')

    # Double ISZERO normalises any non-zero word to 1.
    return LLLnode.from_list(['iszero', ['iszero', value]],
                             typ=BaseType('bool'),
                             pos=getpos(expr))
예제 #11
0
def test_bytearray_node_type():
    twelve_bytes = ByteArrayType(12)

    # Byte arrays with the same max length compare equal.
    assert twelve_bytes == ByteArrayType(12)

    # A different max length, or a different kind of type, compares unequal.
    assert twelve_bytes != ByteArrayType(13)
    assert twelve_bytes != BaseType('int128')
예제 #12
0
def byte_array_to_num(
    arg,
    expr,
    out_type,
    offset=32,
):
    """Return an LLLnode decoding byte array ``arg`` into an integer.

    The array length is checked with ``clamp`` against [0, 32]. The first
    data word is then shifted right (divided by a power of 256) so that its
    leading ``len`` bytes form the number:

    - ``int128``: shift by ``32 - len`` bytes, then clamp to the range
      stored at ``MINNUM``/``MAXNUM``;
    - ``uint256``: shift by ``offset - len`` bytes.

    Args:
        arg: byte-array LLLnode located in ``"memory"`` or ``"storage"``.
        expr: source AST node (not used by this function).
        out_type: ``'int128'`` or ``'uint256'``.
        offset: byte width used for the ``uint256`` shift.

    Raises:
        Exception: if ``arg.location`` or ``out_type`` is unsupported. Both
            cases used to surface as an ``UnboundLocalError``.
    """
    if arg.location == "memory":
        # Length word first, data word right after it.
        lengetter = LLLnode.from_list(['mload', '_sub'],
                                      typ=BaseType('int128'))
        first_el_getter = LLLnode.from_list(['mload', ['add', 32, '_sub']],
                                            typ=BaseType('int128'))
    elif arg.location == "storage":
        # Length at slot sha3_32(arg), data starting one slot later.
        lengetter = LLLnode.from_list(['sload', ['sha3_32', '_sub']],
                                      typ=BaseType('int128'))
        first_el_getter = LLLnode.from_list(
            ['sload', ['add', 1, ['sha3_32', '_sub']]], typ=BaseType('int128'))
    else:
        raise Exception(f"Unsupported location: {arg.location}")

    if out_type == 'int128':
        # NOTE(review): this path shifts by (32 - len) and ignores ``offset``,
        # unlike the uint256 path below -- confirm this is intended.
        result = [
            'clamp', ['mload', MemoryPositions.MINNUM],
            ['div', '_el1', ['exp', 256, ['sub', 32, '_len']]],
            ['mload', MemoryPositions.MAXNUM]
        ]
    elif out_type == 'uint256':
        result = ['div', '_el1', ['exp', 256, ['sub', offset, '_len']]]
    else:
        raise Exception(f"Unsupported output type: {out_type}")

    return LLLnode.from_list([
        'with', '_sub', arg,
        [
            'with', '_el1', first_el_getter,
            [
                'with',
                '_len',
                ['clamp', 0, lengetter, 32],
                result,
            ]
        ]
    ],
                             typ=BaseType(out_type),
                             annotation=f'bytearray to number ({out_type})')
예제 #13
0
def to_bytes32(expr, args, kwargs, context):
    """Convert ``args[0]`` to ``bytes32``.

    For a byte array with max length at most 32, the result is the 32-byte
    word that follows its length word. Any other input is reinterpreted as
    ``bytes32`` without emitting conversion code.

    Raises:
        TypeMismatch: if a byte array's max length exceeds 32.
        Exception: if a byte array is in an unsupported location (previously
            the function silently returned ``None``).
    """
    in_arg = args[0]
    input_type, _len = get_type(in_arg)

    if input_type == 'bytes':
        if _len > 32:
            raise TypeMismatch(
                f"Unable to convert bytes[{_len}] to bytes32, max length is too "
                "large.",
                expr,
            )

        if in_arg.location == "memory":
            # Skip the 32-byte length word.
            return LLLnode.from_list(['mload', ['add', in_arg, 32]],
                                     typ=BaseType('bytes32'),
                                     pos=getpos(expr))
        elif in_arg.location == "storage":
            # Slot sha3_32(arg) holds the length; data begins one slot later.
            return LLLnode.from_list(
                ['sload', ['add', ['sha3_32', in_arg], 1]],
                typ=BaseType('bytes32'),
                pos=getpos(expr))
        raise Exception(f"Unsupported location: {in_arg.location}")

    return LLLnode(value=in_arg.value,
                   args=in_arg.args,
                   typ=BaseType('bytes32'),
                   pos=getpos(expr))
예제 #14
0
def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None):
    """Build the LLL list that returns ``_size`` bytes starting at ``begin_pos``.

    Public functions use the ``return`` opcode. Private functions instead
    push the data onto the stack one 32-byte word at a time and then jump to
    the address stored at ``context.callback_ptr``. In both cases the
    non-reentrant lock release code (``nonreentrant_post``) is emitted
    before leaving.

    Args:
        stmt: AST node of the return statement; its line/column make the
            loop labels unique.
        context: compilation context of the enclosing function.
        begin_pos: memory offset of the return data (int or LLL expression).
        _size: byte length of the return data (int or LLL expression).
        loop_memory_position: memory slot for the push-loop counter; a fresh
            ``uint256`` placeholder is allocated when omitted (private only).

    Returns:
        An LLL list (not an ``LLLnode``).
    """
    from srilang.parser.function_definitions.utils import (
        get_nonreentrant_lock
    )
    _, nonreentrant_post = get_nonreentrant_lock(context.sig, context.global_ctx)
    if context.is_private:
        if loop_memory_position is None:
            loop_memory_position = context.new_placeholder(typ=BaseType('uint256'))

        # Make label for stack push loop.
        label_id = '_'.join([str(x) for x in (context.method_id, stmt.lineno, stmt.col_offset)])
        exit_label = f'make_return_loop_exit_{label_id}'
        start_label = f'make_return_loop_start_{label_id}'

        # Push prepared data onto the stack,
        # in reverse order so it can be popped off in order.
        if isinstance(begin_pos, int) and isinstance(_size, int):
            # static values, unroll the mloads instead.
            # ``_size`` is a byte length (as in the loop below and in the
            # public ``return``), so the words span [begin_pos, begin_pos + _size).
            # NOTE(review): these pushes go in ascending address order while
            # the loop below pushes in descending order; the two only agree
            # for a single word -- confirm the order callers expect.
            mloads = [
                ['mload', pos] for pos in range(begin_pos, begin_pos + _size, 32)
            ]
            return ['seq_unchecked'] + mloads + nonreentrant_post + \
                [['jump', ['mload', context.callback_ptr]]]
        else:
            mloads = [
                'seq_unchecked',
                ['mstore', loop_memory_position, _size],
                ['label', start_label],
                [  # maybe exit loop / break.
                    'if',
                    ['le', ['mload', loop_memory_position], 0],
                    ['goto', exit_label]
                ],
                [  # push onto stack
                    'mload',
                    ['add', begin_pos, ['sub', ['mload', loop_memory_position], 32]]
                ],
                [  # decrement i by 32.
                    'mstore',
                    loop_memory_position,
                    ['sub', ['mload', loop_memory_position], 32],
                ],
                ['goto', start_label],
                ['label', exit_label]
            ]
            return ['seq_unchecked'] + [mloads] + nonreentrant_post + \
                [['jump', ['mload', context.callback_ptr]]]
    else:
        return ['seq_unchecked'] + nonreentrant_post + [['return', begin_pos, _size]]
예제 #15
0
 def decimal(self):
     """Compile a decimal literal into its fixed-point integer form.

     The literal is read as the fraction ``num / den``. It must lie within
     ``[SizeLimits.MINNUM, SizeLimits.MAXNUM]``, and ``den`` must divide
     ``DECIMAL_DIVISOR`` exactly. The emitted value is
     ``num * DECIMAL_DIVISOR // den``.

     Raises:
         InvalidLiteral: if the value is out of range, or has more decimal
             places than ``DECIMAL_DIVISOR`` can represent.
     """
     numstring, num, den = get_number_as_fraction(self.expr, self.context)

     # Range-check num / den without dividing (assumes den > 0 -- TODO confirm).
     lower = SizeLimits.MINNUM * den
     upper = SizeLimits.MAXNUM * den
     if num < lower or num > upper:
         raise InvalidLiteral("Number out of range: " + numstring, self.expr)

     if DECIMAL_DIVISOR % den != 0:
         raise InvalidLiteral(
             "Type 'decimal' has maximum 10 decimal places",
             self.expr
         )

     scaled = num * DECIMAL_DIVISOR // den
     return LLLnode.from_list(
         scaled,
         typ=BaseType('decimal', is_literal=True),
         pos=getpos(self.expr),
     )
예제 #16
0
    def ann_assign(self):
        """Compile an annotated assignment ``name: <type> = <value>``.

        Declares a new memory variable of the annotated type and stores the
        value into it. A ``bytes[32]`` string literal assigned to a
        ``bytes32`` variable is first rewritten as the equivalent ``bytes32``
        integer constant.

        Raises:
            TypeMismatch: if the target is an attribute (``x.y: T = ...``).
            StructureException: if no initial value is given.
        """
        with self.context.assignment_scope():
            declared_typ = parse_type(
                self.stmt.annotation,
                location='memory',
                custom_structs=self.context.structs,
                constants=self.context.constants,
            )
            target = self.stmt.target
            if isinstance(target, sri_ast.Attribute):
                raise TypeMismatch(
                    f'May not set type for field {target.attr}',
                    self.stmt,
                )
            varname = target.id
            # The variable is registered before its value expression is parsed.
            var_pos = self.context.new_variable(varname, declared_typ)
            if self.stmt.value is None:
                raise StructureException(
                    'New variables must be initialized explicitly', self.stmt)

            value = Expr(self.stmt.value, self.context).lll_node

            # bytes[32] literal -> bytes32 variable: use the literal's integer value.
            if (isinstance(value.typ, ByteArrayType)
                    and value.typ.maxlen == 32
                    and isinstance(declared_typ, BaseType)
                    and declared_typ.typ == 'bytes32'
                    and value.typ.is_literal):
                value = LLLnode(
                    bytes_to_int(self.stmt.value.s),
                    typ=BaseType('bytes32'),
                    pos=getpos(self.stmt),
                )

            self._check_valid_assign(value)
            self._check_same_variable_assign(value)
            destination = LLLnode.from_list(
                var_pos,
                typ=declared_typ,
                location='memory',
                pos=getpos(self.stmt),
            )
            return make_setter(destination, value, 'memory', pos=getpos(self.stmt))
예제 #17
0
def to_uint256(expr, args, kwargs, context):
    """Convert ``args[0]`` to ``uint256``.

    Supported inputs:

    - numeric literals: ``int`` values, and ``Decimal`` values truncated
      toward zero;
    - ``int128`` values, required to be non-negative via ``clampge``;
    - ``decimal`` values, required to be non-negative and then divided by
      ``DECIMAL_DIVISOR``;
    - ``bool`` values;
    - ``bytes32``/``address`` values, which are reinterpreted;
    - byte arrays with max length at most 32.

    Raises:
        InvalidLiteral: for out-of-range or unknown numeric literals, byte
            arrays longer than 32, and any other input.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == 'num_literal':
        if isinstance(in_arg, int):
            if not SizeLimits.in_bounds('uint256', in_arg):
                raise InvalidLiteral(f"Number out of range: {in_arg}", expr)
            return LLLnode.from_list(in_arg,
                                     typ=BaseType('uint256'),
                                     pos=getpos(expr))
        elif isinstance(in_arg, Decimal):
            truncated = math.trunc(in_arg)
            if not SizeLimits.in_bounds('uint256', truncated):
                raise InvalidLiteral(
                    f"Number out of range: {truncated}", expr)
            return LLLnode.from_list(truncated,
                                     typ=BaseType('uint256'),
                                     pos=getpos(expr))
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}", expr)

    elif isinstance(in_arg, LLLnode) and input_type == 'int128':
        return LLLnode.from_list(['clampge', in_arg, 0],
                                 typ=BaseType('uint256'),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == 'decimal':
        return LLLnode.from_list(
            ['div', ['clampge', in_arg, 0], DECIMAL_DIVISOR],
            typ=BaseType('uint256'),
            pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == 'bool':
        # NOTE(review): from_list retypes an existing LLLnode in place rather
        # than copying it, so this mutates the caller's node.
        return LLLnode.from_list(in_arg,
                                 typ=BaseType('uint256'),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type in ('bytes32', 'address'):
        return LLLnode(value=in_arg.value,
                       args=in_arg.args,
                       typ=BaseType('uint256'),
                       pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == 'bytes':
        if in_arg.typ.maxlen > 32:
            raise InvalidLiteral(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to uint256",
                expr,
            )
        return byte_array_to_num(in_arg, expr, 'uint256')

    else:
        raise InvalidLiteral(f"Invalid input for uint256: {in_arg}", expr)
예제 #18
0
    def from_list(cls,
                  obj: Any,
                  typ: 'BaseType' = None,
                  location: str = None,
                  pos: Tuple[int, int] = None,
                  annotation: Optional[str] = None,
                  mutable: bool = True,
                  add_gas_estimate: int = 0,
                  valency: Optional[int] = None) -> 'LLLnode':
        """Build an LLLnode from a nested-list S-expression.

        A string ``typ`` is first promoted to ``BaseType(typ)``. Then:

        - An existing ``LLLnode`` is modified in place and returned. It gets
          ``typ`` when one is given, and ``pos``/``location`` only where the
          node has none. ``annotation``, ``mutable``, ``add_gas_estimate``
          and ``valency`` are ignored.
        - A list becomes a node whose value is ``obj[0]``. Its children are
          ``from_list`` of the remaining elements, and they inherit only
          ``pos``.
        - Anything else becomes a leaf with no children. ``valency`` is not
          forwarded for leaves.
        """
        if isinstance(typ, str):
            typ = BaseType(typ)

        if isinstance(obj, LLLnode):
            # note: modify-and-return -- the caller's node itself is updated
            # (CC 20191121).
            if typ is not None:
                obj.typ = typ
            if obj.pos is None:
                obj.pos = pos
            if obj.location is None:
                obj.location = location
            return obj

        common = dict(
            location=location,
            pos=pos,
            annotation=annotation,
            mutable=mutable,
            add_gas_estimate=add_gas_estimate,
        )

        if not isinstance(obj, list):
            return cls(obj, [], typ, **common)

        head = obj[0]
        children = [cls.from_list(item, pos=pos) for item in obj[1:]]
        return cls(head, children, typ, valency=valency, **common)
예제 #19
0
def base_type_conversion(orig, frm, to, pos, in_function_call=False):
    """Convert the value ``orig`` from base type ``frm`` to base type ``to``.

    Only these conversions are accepted:
    - identical types;
    - an ``int128`` literal to ``uint256``;
    - a ``ContractType`` to ``address``;
    - ``int128`` to ``decimal``, which multiplies by ``DECIMAL_DIVISOR``.

    When ``frm`` is a literal, its value is range-checked against whichever
    of ``frm``/``to`` is ``int128`` or ``uint256``. A ``None`` value
    (inserted by ``empty()``) becomes the constant 0.

    Args:
        orig: LLLnode of the value; it is unwrapped from its location first.
        frm: source type.
        to: target type.
        pos: passed to raised exceptions as the error location.
        in_function_call: not used by this function.

    Raises:
        TypeMismatch: if either type is not a ``BaseType`` or the conversion
            is not one of those listed above.
        InvalidLiteral: if a literal is out of range.
    """
    orig = unwrap_location(orig)

    # do the base type check so we can use BaseType attributes
    if not (isinstance(frm, BaseType) and isinstance(to, BaseType)):
        raise TypeMismatch(
            f"Base type conversion from or to non-base type: {frm} {to}", pos)

    if getattr(frm, 'is_literal', False):
        for bound_typ in (frm.typ, to.typ):
            if bound_typ in ('int128', 'uint256') and \
                    not SizeLimits.in_bounds(bound_typ, orig.value):
                raise InvalidLiteral(f"Number out of range: {orig.value}", pos)

    typ_pair = (frm.typ, to.typ)
    # All four flags are evaluated eagerly, as each one is always computed.
    int128_to_decimal = typ_pair == ('int128', 'decimal')
    same_type = frm.typ == to.typ
    literal_int128_to_uint256 = frm.is_literal and typ_pair == ('int128', 'uint256')
    contract_to_address = isinstance(frm, ContractType) and to.typ == 'address'

    if not any((same_type, literal_int128_to_uint256, contract_to_address,
                int128_to_decimal)):
        raise TypeMismatch(
            f"Typecasting from base type {frm} to {to} unavailable", pos)

    # handle None value inserted by `empty()`
    if orig.value is None:
        return LLLnode.from_list(0, typ=to)

    if int128_to_decimal:
        return LLLnode.from_list(
            ['mul', orig, DECIMAL_DIVISOR],
            typ=BaseType('decimal'),
        )

    return LLLnode(orig.value,
                   orig.args,
                   typ=to,
                   add_gas_estimate=orig.add_gas_estimate)
예제 #20
0
def test_get_size_of_type():
    expected_sizes = [
        (BaseType('int128'), 1),
        (ByteArrayType(12), 3),
        (ByteArrayType(33), 4),
        (ListType(BaseType('int128'), 10), 10),
        (TupleType([BaseType('int128'), BaseType('decimal')]), 2),
        (StructType({
            'a': BaseType('int128'),
            'b': BaseType('decimal')
        }, 'Foo'), 2),
    ]
    for typ, size in expected_sizes:
        assert get_size_of_type(typ) == size

    # Don't allow unknown types.
    with raises(Exception):
        get_size_of_type(int)

    # Maps are not supported for function arguments or outputs.
    with raises(Exception):
        get_size_of_type(MappingType(BaseType('int128'), BaseType('int128')))
예제 #21
0
    def arithmetic(self):
        """Compile a binary arithmetic expression (``+ - * / % **``).

        Both operands must be numeric. When one operand is ``uint256`` and
        the other an ``int128`` literal that fits in ``uint256``, the literal
        is retyped to ``uint256`` (the right operand is checked first).
        After that, both operand types must match.

        Operands are bound to ``l`` and ``r``. Checks on the result:
        - ``int128`` and ``decimal`` results are clamped to the ranges stored
          at ``MINNUM``/``MAXNUM`` and ``MINDECIMAL``/``MAXDECIMAL``;
        - ``uint256`` add, sub, mul and pow carry their own assertions.

        Raises:
            TypeMismatch: for non-numeric or mismatched operand types, ``**``
                on decimals, or a dynamic exponent with a non-integer base.
            ZeroDivisionException: for a literal zero divisor or modulus.
            StructureException: for an unsupported operator.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.right, self.context)

        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            raise TypeMismatch(
                f"Unsupported types for arithmetic op: {left.typ} {right.typ}",
                self.expr,
            )

        arithmetic_pair = {left.typ.typ, right.typ.typ}
        pos = getpos(self.expr)

        # Special case with uint256 where an int literal may be cast.
        if arithmetic_pair == {'uint256', 'int128'}:
            # Check right side literal.
            if right.typ.is_literal and SizeLimits.in_bounds('uint256', right.value):
                right = LLLnode.from_list(
                    right.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=pos,
                )

            # Check left side literal.
            elif left.typ.is_literal and SizeLimits.in_bounds('uint256', left.value):
                left = LLLnode.from_list(
                    left.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=pos,
                )

        if left.typ.typ == "decimal" and isinstance(self.expr.op, sri_ast.Pow):
            raise TypeMismatch(
                "Cannot perform exponentiation on decimal values.",
                self.expr,
            )

        # Only allow explicit conversions to occur.
        if left.typ.typ != right.typ.typ:
            raise TypeMismatch(
                f"Cannot implicitly convert {left.typ.typ} to {right.typ.typ}.",
                self.expr,
            )

        ltyp, rtyp = left.typ.typ, right.typ.typ
        if isinstance(self.expr.op, (sri_ast.Add, sri_ast.Sub)):
            new_typ = BaseType(ltyp)
            op = 'add' if isinstance(self.expr.op, sri_ast.Add) else 'sub'

            if ltyp == 'uint256' and isinstance(self.expr.op, sri_ast.Add):
                # safeadd: a wrapped sum is smaller than l.
                arith = ['seq',
                         ['assert', ['ge', ['add', 'l', 'r'], 'l']],
                         ['add', 'l', 'r']]

            elif ltyp == 'uint256' and isinstance(self.expr.op, sri_ast.Sub):
                # safesub: underflow exactly when r > l.
                arith = ['seq',
                         ['assert', ['ge', 'l', 'r']],
                         ['sub', 'l', 'r']]

            elif ltyp == rtyp:
                arith = [op, 'l', 'r']

            else:
                raise Exception(f"Unsupported Operation '{op}({ltyp}, {rtyp})'")

        elif isinstance(self.expr.op, sri_ast.Mult):
            new_typ = BaseType(ltyp)
            if ltyp == rtyp == 'uint256':
                # No overflow iff ans / l == r (or l == 0).
                arith = ['with', 'ans', ['mul', 'l', 'r'],
                         ['seq',
                             ['assert',
                                 ['or',
                                     ['eq', ['div', 'ans', 'l'], 'r'],
                                     ['iszero', 'l']]],
                             'ans']]

            elif ltyp == rtyp == 'int128':
                # TODO should this be 'smul' (note edge cases in YP for smul)
                arith = ['mul', 'l', 'r']

            elif ltyp == rtyp == 'decimal':
                # TODO should this be smul
                arith = ['with', 'ans', ['mul', 'l', 'r'],
                         ['seq',
                             ['assert',
                                 ['or',
                                     ['eq', ['sdiv', 'ans', 'l'], 'r'],
                                     ['iszero', 'l']]],
                             ['sdiv', 'ans', DECIMAL_DIVISOR]]]
            else:
                raise Exception(f"Unsupported Operation 'mul({ltyp}, {rtyp})'")

        elif isinstance(self.expr.op, sri_ast.Div):
            if right.typ.is_literal and right.value == 0:
                raise ZeroDivisionException("Cannot divide by 0.", self.expr)

            new_typ = BaseType(ltyp)
            if ltyp == rtyp == 'uint256':
                arith = ['div', 'l', ['clamp_nonzero', 'r']]

            elif ltyp == rtyp == 'int128':
                arith = ['sdiv', 'l', ['clamp_nonzero', 'r']]

            elif ltyp == rtyp == 'decimal':
                arith = ['sdiv',
                         # TODO check overflow cases, also should it be smul
                         ['mul', 'l', DECIMAL_DIVISOR],
                         ['clamp_nonzero', 'r']]

            else:
                raise Exception(f"Unsupported Operation 'div({ltyp}, {rtyp})'")

        elif isinstance(self.expr.op, sri_ast.Mod):
            if right.typ.is_literal and right.value == 0:
                raise ZeroDivisionException("Cannot calculate modulus of 0.", self.expr)

            new_typ = BaseType(ltyp)

            if ltyp == rtyp == 'uint256':
                arith = ['mod', 'l', ['clamp_nonzero', 'r']]
            elif ltyp == rtyp:
                # TODO should this be regular mod
                arith = ['smod', 'l', ['clamp_nonzero', 'r']]

            else:
                raise Exception(f"Unsupported Operation 'mod({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, sri_ast.Pow):
            if ltyp != 'int128' and ltyp != 'uint256' and isinstance(self.expr.right, sri_ast.Name):
                raise TypeMismatch(
                    "Cannot use dynamic values as exponents, for unit base types",
                    self.expr,
                )
            new_typ = BaseType(ltyp)

            if ltyp == rtyp == 'uint256':
                # Overflow guard for ``l ** r``. Accept when:
                #   - r <= 1 (the result is 1 or l);
                #   - l <= 1 (the result is 0 or 1 for any r). Without this
                #     case, 0 ** r and 1 ** r with r >= 2 were rejected,
                #     since l ** r is not greater than l;
                #   - the result still exceeds the base.
                # NOTE(review): ``l < l ** r`` is a heuristic and does not
                # catch every wrapped result.
                arith = ['seq',
                         ['assert',
                             ['or',
                                 # r == 1 | iszero(r)
                                 # could be simplified to ~(r & 1)
                                 ['or', ['eq', 'r', 1], ['iszero', 'r']],
                                 ['or',
                                     ['lt', 'l', 2],
                                     ['lt', 'l', ['exp', 'l', 'r']]]]],
                         ['exp', 'l', 'r']]
            elif ltyp == rtyp == 'int128':
                arith = ['exp', 'l', 'r']

            else:
                raise TypeMismatch('Only whole number exponents are supported', self.expr)
        else:
            raise StructureException(f"Unsupported binary operator: {self.expr.op}", self.expr)

        p = ['seq']

        if new_typ.typ == 'int128':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINNUM],
                arith,
                ['mload', MemoryPositions.MAXNUM],
            ])
        elif new_typ.typ == 'decimal':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINDECIMAL],
                arith,
                ['mload', MemoryPositions.MAXDECIMAL],
            ])
        elif new_typ.typ == 'uint256':
            p.append(arith)
        else:
            raise Exception(f"{arith} {new_typ}")

        p = ['with', 'l', left, ['with', 'r', right, p]]
        return LLLnode.from_list(p, typ=new_typ, pos=pos)
예제 #22
0
    def parse_for(self):
        """Compile a ``for`` statement into an LLL ``repeat`` loop.

        List iteration is delegated to ``parse_for_list``. Any other
        iterator must be a call to ``range`` with one or two arguments, in
        one of three forms:

        1. ``range(n)``: ``n`` rounds starting from 0.
        2. ``range(a, b)``: ``b - a`` rounds starting from ``a``. Used when
           the second argument passes ``_check_valid_range_constant``.
        3. ``range(x, x + n)``: ``n`` rounds starting from ``x``. Here ``x``
           may be a runtime expression, but it must be AST-identical in both
           arguments.

        The loop variable is declared as an ``int128`` variable inside a
        block scope. It is removed from ``context.vars`` and
        ``context.forvars`` once the body has been compiled.

        Raises:
            StructureException: if the iterator is not a supported ``range``
                form, or the loop would run fewer than one round.
        """
        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self.parse_for_list()

        # Past this point the iterator must be a call to `range`.
        if not isinstance(self.stmt.iter, sri_ast.Call):
            if isinstance(self.stmt.iter, sri_ast.Subscript):
                raise StructureException("Cannot iterate over a nested list",
                                         self.stmt.iter)
            raise StructureException(
                f"Cannot iterate over '{type(self.stmt.iter).__name__}' object",
                self.stmt.iter)
        if getattr(self.stmt.iter.func, 'id', None) != "range":
            raise StructureException(
                "Non-literals cannot be used as loop range",
                self.stmt.iter.func)
        if len(self.stmt.iter.args) not in {1, 2}:
            raise StructureException(
                f"Range expects between 1 and 2 arguments, got {len(self.stmt.iter.args)}",
                self.stmt.iter.func)

        block_scope_id = id(self.stmt)
        with self.context.make_blockscope(block_scope_id):
            # Get arg0
            arg0 = self.stmt.iter.args[0]
            num_of_args = len(self.stmt.iter.args)

            # Type 1 for, e.g. for i in range(10): ...
            if num_of_args == 1:
                arg0_val = self._get_range_const_value(arg0)
                start = LLLnode.from_list(0,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = arg0_val

            # Type 2 for, e.g. for i in range(100, 110): ...
            elif self._check_valid_range_constant(self.stmt.iter.args[1],
                                                  raise_exception=False)[0]:
                arg0_val = self._get_range_const_value(arg0)
                arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
                start = LLLnode.from_list(arg0_val,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = LLLnode.from_list(arg1_val - arg0_val,
                                           typ='int128',
                                           pos=getpos(self.stmt))

            # Type 3 for, e.g. for i in range(x, x + 10): ...
            else:
                arg1 = self.stmt.iter.args[1]
                # The second argument must be `start + rounds`.
                if not isinstance(arg1, sri_ast.BinOp) or not isinstance(
                        arg1.op, sri_ast.Add):
                    raise StructureException(
                        ("Two-arg for statements must be of the form `for i "
                         "in range(start, start + rounds): ...`"),
                        arg1,
                    )

                # ... and its left operand must match the first argument exactly.
                if not sri_ast.compare_nodes(arg0, arg1.left):
                    raise StructureException(
                        ("Two-arg for statements of the form `for i in "
                         "range(x, x + y): ...` must have x identical in both "
                         f"places: {sri_ast.ast_to_dict(arg0)} {sri_ast.ast_to_dict(arg1.left)}"
                         ),
                        self.stmt.iter,
                    )

                rounds = self._get_range_const_value(arg1.right)
                start = Expr.parse_value_expr(arg0, self.context)

            # `rounds` is an LLLnode in form 2. In forms 1 and 3 it is whatever
            # _get_range_const_value returned (presumably an int -- TODO confirm).
            r = rounds if isinstance(rounds, int) else rounds.value
            if r < 1:
                raise StructureException(
                    f"For loop has invalid number of iterations ({r}),"
                    " the value must be greater than zero", self.stmt.iter)

            varname = self.stmt.target.id
            pos = self.context.new_variable(varname,
                                            BaseType('int128'),
                                            pos=getpos(self.stmt))
            # Mark the name as a loop variable while the body is compiled.
            self.context.forvars[varname] = True
            o = LLLnode.from_list(
                [
                    'repeat', pos, start, rounds,
                    parse_body(self.stmt.body, self.context)
                ],
                typ=None,
                pos=getpos(self.stmt),
            )
            # The loop variable goes out of scope with the loop.
            del self.context.vars[varname]
            del self.context.forvars[varname]

        return o
예제 #23
0
def get_length(arg):
    """Build an LLL node that reads the length word of the byte array ``arg``.

    For a ``memory`` location the length is ``mload``-ed from ``arg`` itself;
    for a ``storage`` location it is ``sload``-ed from slot ``sha3_32(arg)``.
    The resulting node is typed ``int128``.

    Returns ``None`` for any other ``arg.location`` (no node is built).
    """
    location = arg.location
    if location == "storage":
        length_expr = ['sload', ['sha3_32', arg]]
    elif location == "memory":
        length_expr = ['mload', arg]
    else:
        # NOTE(review): any other location silently yields None; confirm
        # callers never pass such an argument.
        return None
    return LLLnode.from_list(length_expr, typ=BaseType('int128'))
예제 #24
0
 def attribute(self):
     """Compile an attribute access expression (``<value>.<attr>``) to LLL.

     The forms below are checked in this order. The order matters: for
     example ``self.balance`` must reach the ``balance`` branch before the
     generic ``self.<global>`` branch.

     * ``<address>.balance``: ``uint256``. Uses ``selfbalance`` for
       ``self.balance`` when the istanbul ruleset is active, else ``balance``.
     * ``<address>.codesize``: ``int128`` result of ``extcodesize``.
     * ``<address>.is_contract``: ``bool``, true when ``extcodesize > 0``.
     * ``<address>.codehash``: ``bytes32`` result of ``extcodehash``
       (constantinople and later).
     * ``self.<name>``: storage location of a declared global.
     * ``msg.*`` / ``block.*`` / ``tx.*`` / ``chain.*``: environment values.
     * Anything else is treated as a variable location. A contract-typed
       value is returned unchanged; a struct yields the requested member.

     Raises:
         TypeMismatch: non-address operand for an address attribute, member
             access on a non-struct, or an unknown struct member.
         EvmVersionException: ``codehash`` / ``chain.id`` on a ruleset that
             lacks the opcode.
         VariableDeclarationException: ``self.<name>`` for an undeclared
             global.
         StructureException: ``msg.sender`` in a private function, or an
             unsupported environment keyword.
         NonPayableViolation: ``msg.value`` in a non-payable function.
     """
     # x.balance: balance of address x
     if self.expr.attr == 'balance':
         addr = Expr.parse_value_expr(self.expr.value, self.context)
         if not is_base_type(addr.typ, 'address'):
             raise TypeMismatch(
                 "Type mismatch: balance keyword expects an address as input",
                 self.expr
             )
         # `self.balance` uses the SELFBALANCE opcode when the istanbul
         # ruleset is active; everything else reads BALANCE of the address.
         if (
             isinstance(self.expr.value, sri_ast.Name) and
             self.expr.value.id == "self" and
             version_check(begin="istanbul")
         ):
             seq = ['selfbalance']
         else:
             seq = ['balance', addr]
         return LLLnode.from_list(
             seq,
             typ=BaseType('uint256'),
             location=None,
             pos=getpos(self.expr),
         )
     # x.codesize: codesize of address x
     # x.is_contract: whether address x has any code
     elif self.expr.attr == 'codesize' or self.expr.attr == 'is_contract':
         addr = Expr.parse_value_expr(self.expr.value, self.context)
         if not is_base_type(addr.typ, 'address'):
             # Name the attribute actually used, so a bad `x.is_contract`
             # does not produce an error that talks about `codesize`.
             raise TypeMismatch(
                 f"Type mismatch: {self.expr.attr} keyword expects an address as input",
                 self.expr,
             )
         if self.expr.attr == 'codesize':
             eval_code = ['extcodesize', addr]
             output_type = 'int128'
         else:
             eval_code = ['gt', ['extcodesize', addr], 0]
             output_type = 'bool'
         return LLLnode.from_list(
             eval_code,
             typ=BaseType(output_type),
             location=None,
             pos=getpos(self.expr),
         )
     # x.codehash: keccak of address x
     elif self.expr.attr == 'codehash':
         addr = Expr.parse_value_expr(self.expr.value, self.context)
         if not is_base_type(addr.typ, 'address'):
             raise TypeMismatch(
                 "codehash keyword expects an address as input",
                 self.expr,
             )
         if not version_check(begin="constantinople"):
             raise EvmVersionException(
                 "address.codehash is unavailable prior to constantinople ruleset",
                 self.expr
             )
         return LLLnode.from_list(
             ['extcodehash', addr],
             typ=BaseType('bytes32'),
             location=None,
             pos=getpos(self.expr)
         )
     # self.x: global attribute
     elif isinstance(self.expr.value, sri_ast.Name) and self.expr.value.id == "self":
         if self.expr.attr not in self.context.globals:
             raise VariableDeclarationException(
                 "Persistent variable undeclared: " + self.expr.attr,
                 self.expr,
             )
         var = self.context.globals[self.expr.attr]
         # The node's value is the variable's storage position, not its
         # contents; location='storage' tells consumers how to read it.
         return LLLnode.from_list(
             var.pos,
             typ=var.typ,
             location='storage',
             pos=getpos(self.expr),
             annotation='self.' + self.expr.attr,
         )
     # Reserved keywords
     elif (
         isinstance(self.expr.value, sri_ast.Name) and
         self.expr.value.id in ENVIRONMENT_VARIABLES
     ):
         key = self.expr.value.id + "." + self.expr.attr
         if key == "msg.sender":
             if self.context.is_private:
                 raise StructureException(
                     "msg.sender not allowed in private functions.", self.expr
                 )
             return LLLnode.from_list(['caller'], typ='address', pos=getpos(self.expr))
         elif key == "msg.value":
             if not self.context.is_payable:
                 raise NonPayableViolation(
                     "Cannot use msg.value in a non-payable function", self.expr,
                 )
             return LLLnode.from_list(
                 ['callvalue'],
                 typ=BaseType('uint256'),
                 pos=getpos(self.expr),
             )
         elif key == "msg.gas":
             return LLLnode.from_list(
                 ['gas'],
                 typ='uint256',
                 pos=getpos(self.expr),
             )
         elif key == "block.difficulty":
             return LLLnode.from_list(
                 ['difficulty'],
                 typ='uint256',
                 pos=getpos(self.expr),
             )
         elif key == "block.timestamp":
             return LLLnode.from_list(
                 ['timestamp'],
                 typ=BaseType('uint256'),
                 pos=getpos(self.expr),
             )
         elif key == "block.coinbase":
             return LLLnode.from_list(['coinbase'], typ='address', pos=getpos(self.expr))
         elif key == "block.number":
             return LLLnode.from_list(['number'], typ='uint256', pos=getpos(self.expr))
         elif key == "block.prevhash":
             # Hash of the block whose number is (current number - 1).
             return LLLnode.from_list(
                 ['blockhash', ['sub', 'number', 1]],
                 typ='bytes32',
                 pos=getpos(self.expr),
             )
         elif key == "tx.origin":
             return LLLnode.from_list(['origin'], typ='address', pos=getpos(self.expr))
         elif key == "chain.id":
             if not version_check(begin="istanbul"):
                 raise EvmVersionException(
                     "chain.id is unavailable prior to istanbul ruleset",
                     self.expr
                 )
             return LLLnode.from_list(['chainid'], typ='uint256', pos=getpos(self.expr))
         else:
             raise StructureException("Unsupported keyword: " + key, self.expr)
     # Other variables
     else:
         sub = Expr.parse_variable_location(self.expr.value, self.context)
         # contract type
         if isinstance(sub.typ, ContractType):
             return sub
         if not isinstance(sub.typ, StructType):
             raise TypeMismatch(
                 "Type mismatch: member variable access not expected",
                 self.expr.value,
             )
         attrs = list(sub.typ.members.keys())
         if self.expr.attr not in attrs:
             raise TypeMismatch(
                 f"Member {self.expr.attr} not found. Only the following available: "
                 f"{' '.join(attrs)}",
                 self.expr,
             )
         return add_variable_offset(sub, self.expr.attr, pos=getpos(self.expr))
예제 #25
0
def call_self_private(stmt_expr, context, sig):
    """Compile a call to a private function of the current contract into LLL.

    The call is an internal jump, not a message call. In order:

    1. every memory-resident local variable of the caller is pushed onto the
       stack (``push_local_vars``);
    2. the arguments are packed into memory by ``pack_arguments`` and then
       pushed onto the stack, dynamic section first (``push_args``);
    3. a return address is pushed and control jumps to ``priv_<method_id>``;
    4. once control is back at the ``jumpdest``, the callee's return values
       are popped off the stack into an output placeholder
       (``pop_return_values``);
    5. the saved local variables are popped back into memory
       (``pop_local_vars``).

    Args:
        stmt_expr: the call AST node. Its ``lineno``/``col_offset`` are used
            to build the generated label names.
        context: compilation context of the calling function.
        sig: overwritten immediately by the signature returned from
            ``call_lookup_specs``; the passed-in value is not used.

    Returns:
        An LLLnode typed ``sig.output_type`` with ``location='memory'``. If
        the function has no output type, the call body is wrapped in ``pop``.

    Raises:
        StructureException: if no arguments are given but the signature
            declares some.
    """
    # ** Private Call **
    # Steps:
    # (x) push current local variables
    # (x) push arguments
    # (x) push jumpdest (callback ptr)
    # (x) jump to label
    # (x) pop return values
    # (x) pop local variables

    method_name, expr_args, sig = call_lookup_specs(stmt_expr, context)
    pre_init = []
    pop_local_vars = []
    push_local_vars = []
    pop_return_values = []
    push_args = []

    # Push local variables.
    var_slots = [(v.pos, v.size) for name, v in context.vars.items()
                 if v.location == 'memory']
    if var_slots:
        var_slots.sort(key=lambda x: x[0])
        # [mem_from, mem_to) spans all memory locals; the last variable's
        # size is multiplied by 32, i.e. treated as a count of 32-byte words.
        mem_from, mem_to = var_slots[0][
            0], var_slots[-1][0] + var_slots[-1][1] * 32

        i_placeholder = context.new_placeholder(BaseType('uint256'))
        local_save_ident = f"_{stmt_expr.lineno}_{stmt_expr.col_offset}"
        push_loop_label = 'save_locals_start' + local_save_ident
        pop_loop_label = 'restore_locals_start' + local_save_ident

        # Spans larger than 320 bytes are saved/restored with a runtime loop;
        # smaller ones are unrolled into one mload/mstore per word.
        if mem_to - mem_from > 320:
            # Push words from mem_from upwards...
            push_local_vars = [['mstore', i_placeholder, mem_from],
                               ['label', push_loop_label],
                               ['mload', ['mload', i_placeholder]],
                               [
                                   'mstore', i_placeholder,
                                   ['add', ['mload', i_placeholder], 32]
                               ],
                               [
                                   'if',
                                   ['lt', ['mload', i_placeholder], mem_to],
                                   ['goto', push_loop_label]
                               ]]
            # ...and pop them back from the top word downwards (stack order).
            pop_local_vars = [['mstore', i_placeholder, mem_to - 32],
                              ['label', pop_loop_label],
                              ['mstore', ['mload', i_placeholder], 'pass'],
                              [
                                  'mstore', i_placeholder,
                                  ['sub', ['mload', i_placeholder], 32]
                              ],
                              [
                                  'if',
                                  ['ge', ['mload', i_placeholder], mem_from],
                                  ['goto', pop_loop_label]
                              ]]
        else:
            push_local_vars = [['mload', pos]
                               for pos in range(mem_from, mem_to, 32)]
            pop_local_vars = [['mstore', pos, 'pass']
                              for pos in range(mem_to - 32, mem_from - 32, -32)
                              ]

    # Push Arguments
    if expr_args:
        inargs, inargsize, arg_pos = pack_arguments(
            sig,
            expr_args,
            context,
            stmt_expr,
            return_placeholder=False,
        )
        push_args += [
            inargs
        ]  # copy arguments first, to not mess up the push/pop sequencing.

        # End of the static (head) section of the packed arguments.
        static_arg_size = 32 * sum(
            [get_static_size_of_type(arg.typ) for arg in expr_args])
        static_pos = int(arg_pos + static_arg_size)
        needs_dyn_section = any(
            [has_dynamic_data(arg.typ) for arg in expr_args])

        if needs_dyn_section:
            ident = f'push_args_{sig.method_id}_{stmt_expr.lineno}_{stmt_expr.col_offset}'
            start_label = ident + '_start'
            end_label = ident + '_end'
            i_placeholder = context.new_placeholder(BaseType('uint256'))

            # Calculate copy start position.
            # Given | static | dynamic | section in memory,
            # copy backwards so the values are in order on the stack.
            # We calculate i, the end of the whole encoded part
            # (i.e. the starting index for copy)
            # by taking ceil32(len<arg>) + offset<arg> + arg_pos
            # for the last dynamic argument and arg_pos is the start
            # the whole argument section.
            # NOTE(review): last_idx is only bound when some argument is
            # ByteArrayLike; confirm has_dynamic_data() can never be true for
            # an argument list without one, or this raises UnboundLocalError.
            idx = 0
            for arg in expr_args:
                if isinstance(arg.typ, ByteArrayLike):
                    last_idx = idx
                idx += get_static_size_of_type(arg.typ)
            push_args += [[
                'with', 'offset', ['mload', arg_pos + last_idx * 32],
                [
                    'with', 'len_pos', ['add', arg_pos, 'offset'],
                    [
                        'with', 'len_value', ['mload', 'len_pos'],
                        [
                            'mstore', i_placeholder,
                            ['add', 'len_pos', ['ceil32', 'len_value']]
                        ]
                    ]
                ]
            ]]
            # loop from end of dynamic section to start of dynamic section,
            # pushing each element onto the stack.
            push_args += [
                ['label', start_label],
                [
                    'if', ['lt', ['mload', i_placeholder], static_pos],
                    ['goto', end_label]
                ],
                ['mload', ['mload', i_placeholder]],
                [
                    'mstore', i_placeholder,
                    ['sub', ['mload', i_placeholder], 32]
                ],  # decrease i
                ['goto', start_label],
                ['label', end_label]
            ]

        # push static section
        push_args += [['mload', pos]
                      for pos in reversed(range(arg_pos, static_pos, 32))]
    elif sig.args:
        raise StructureException(
            f"Wrong number of args for: {sig.name} (0 args given, expected {len(sig.args)})",
            stmt_expr)

    # Jump to function label.
    # The pushed return address is pc + 6; 6 is presumably the byte distance
    # from the PC opcode to the jumpdest below — confirm against assembler
    # output before changing the surrounding instructions.
    jump_to_func = [
        ['add', ['pc'], 6],  # set callback pointer.
        ['goto', f'priv_{sig.method_id}'],
        ['jumpdest'],
    ]

    # Pop return values.
    # `returner` is the final expression of the call body: 0 when there is
    # no output type, otherwise whatever call_make_placeholder returns.
    returner = [0]
    if sig.output_type:
        output_placeholder, returner, output_size = call_make_placeholder(
            stmt_expr, context, sig)
        if output_size > 0:
            # (memory offset, ...) pairs of byte arrays that still need their
            # data popped by the dynamic unpacker below.
            dynamic_offsets = []
            if isinstance(sig.output_type, (BaseType, ListType)):
                pop_return_values = [[
                    'mstore', ['add', output_placeholder, pos], 'pass'
                ] for pos in range(0, output_size, 32)]
            elif isinstance(sig.output_type, ByteArrayLike):
                dynamic_offsets = [(0, sig.output_type)]
                pop_return_values = [
                    ['pop', 'pass'],
                ]
            elif isinstance(sig.output_type, TupleLike):
                static_offset = 0
                pop_return_values = []
                for name, typ in sig.output_type.tuple_items():
                    if isinstance(typ, ByteArrayLike):
                        # Pop the head word for this member, then remember
                        # to unpack its data from the offset it points at.
                        pop_return_values.append([
                            'mstore',
                            ['add', output_placeholder, static_offset], 'pass'
                        ])
                        dynamic_offsets.append(([
                            'mload',
                            ['add', output_placeholder, static_offset]
                        ], name))
                        static_offset += 32
                    else:
                        member_output_size = get_size_of_type(typ) * 32
                        pop_return_values.extend([[
                            'mstore', ['add', output_placeholder, pos], 'pass'
                        ] for pos in range(static_offset, static_offset +
                                           member_output_size, 32)])
                        static_offset += member_output_size

            # append dynamic unpacker.
            dyn_idx = 0
            for in_memory_offset, _out_type in dynamic_offsets:
                ident = f"{stmt_expr.lineno}_{stmt_expr.col_offset}_arg_{dyn_idx}"
                dyn_idx += 1
                start_label = 'dyn_unpack_start_' + ident
                end_label = 'dyn_unpack_end_' + ident
                i_placeholder = context.new_placeholder(
                    typ=BaseType('uint256'))
                begin_pos = ['add', output_placeholder, in_memory_offset]
                # loop until length.
                # Pops the length word first, then ceil32(length) bytes of
                # data into the 32-byte words after it.
                o = LLLnode.from_list(
                    [
                        'seq_unchecked',
                        ['mstore', begin_pos, 'pass'],  # get len
                        ['mstore', i_placeholder, 0],
                        ['label', start_label],
                        [  # break
                            'if',
                            [
                                'ge', ['mload', i_placeholder],
                                ['ceil32', ['mload', begin_pos]]
                            ], ['goto', end_label]
                        ],
                        [  # pop into correct memory slot.
                            'mstore',
                            [
                                'add', ['add', begin_pos, 32],
                                ['mload', i_placeholder]
                            ],
                            'pass',
                        ],
                        # increment i
                        [
                            'mstore', i_placeholder,
                            ['add', 32, ['mload', i_placeholder]]
                        ],
                        ['goto', start_label],
                        ['label', end_label]
                    ],
                    typ=None,
                    annotation='dynamic unpacker',
                    pos=getpos(stmt_expr))
                pop_return_values.append(o)

    call_body = list(
        itertools.chain(
            ['seq_unchecked'],
            pre_init,
            push_local_vars,
            push_args,
            jump_to_func,
            pop_return_values,
            pop_local_vars,
            [returner],
        ))
    # If we have no return, we need to pop off
    pop_returner_call_body = ['pop', call_body
                              ] if sig.output_type is None else call_body

    # NOTE(review): sig.gas is passed as add_gas_estimate and also added to
    # o.gas below; depending on how from_list applies add_gas_estimate this
    # may count it twice — confirm.
    o = LLLnode.from_list(pop_returner_call_body,
                          typ=sig.output_type,
                          location='memory',
                          pos=getpos(stmt_expr),
                          annotation=f'Internal Call: {method_name}',
                          add_gas_estimate=sig.gas)
    o.gas += sig.gas
    return o
예제 #26
0
    def parse_return(self):
        """Compile a ``return`` statement of the current function into LLL.

        The returned value is checked against ``self.context.return_type``,
        materialised in memory and handed to ``make_return_stmt``:

        * No declared return type: only a bare ``return`` is accepted.
        * Base types: the value is stored at memory offset 0 and 32 bytes
          are returned. Integer literals are range-checked first.
        * Byte arrays (``ByteArrayLike``): the value is copied into a fresh
          placeholder, directly preceded by a word holding 32 (presumably
          the ABI head offset), and returned from that word.
        * Lists: a memory list that is not a literal (``multi``) is returned
          in place. Otherwise it is first copied into a new memory
          placeholder.
        * Structs and tuples: delegated to ``gen_tuple_return``.

        Raises:
            TypeMismatch: unexpected or missing return value, or an
                incompatible type.
            InvalidLiteral: an integer literal out of range for the return
                type.
            StructureException: a returned tuple whose length or member
                kinds differ from the declared return type.
            Exception: a byte array located in neither storage nor memory.
        """
        if self.context.return_type is None:
            if self.stmt.value:
                raise TypeMismatch("Not expecting to return a value",
                                   self.stmt)
            return LLLnode.from_list(
                make_return_stmt(self.stmt, self.context, 0, 0),
                typ=None,
                pos=getpos(self.stmt),
                valency=0,
            )
        if not self.stmt.value:
            raise TypeMismatch("Expecting to return a value", self.stmt)

        sub = Expr(self.stmt.value, self.context).lll_node

        # Returning a value (most common case)
        if isinstance(sub.typ, BaseType):
            sub = unwrap_location(sub)

            if self.context.return_type != sub.typ and not sub.typ.is_literal:
                raise TypeMismatch(
                    f"Trying to return base type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )
            # Parsed as: literal and (A or (B and C)), since `and` binds
            # tighter than `or`.
            # NOTE(review): A compares return_type.typ (used as a type-name
            # string on the next lines) with sub.typ (a BaseType), so it looks
            # like it is never true. Non-int literals then fall through to the
            # is_base_type branch below. Changing it to sub.typ.typ would send
            # them into SizeLimits.in_bounds, so confirm before "fixing".
            elif sub.typ.is_literal and (
                    self.context.return_type.typ == sub.typ
                    or 'int' in self.context.return_type.typ
                    and 'int' in sub.typ.typ):  # noqa: E501
                if not SizeLimits.in_bounds(self.context.return_type.typ,
                                            sub.value):
                    raise InvalidLiteral(
                        "Number out of range: " + str(sub.value), self.stmt)
                else:
                    return LLLnode.from_list(
                        [
                            'seq', ['mstore', 0, sub],
                            make_return_stmt(self.stmt, self.context, 0, 32)
                        ],
                        typ=None,
                        pos=getpos(self.stmt),
                        valency=0,
                    )
            # Same base type, or an int128 value returned as int256.
            elif is_base_type(sub.typ, self.context.return_type.typ) or (
                    is_base_type(sub.typ, 'int128') and is_base_type(
                        self.context.return_type, 'int256')):  # noqa: E501
                return LLLnode.from_list(
                    [
                        'seq', ['mstore', 0, sub],
                        make_return_stmt(self.stmt, self.context, 0, 32)
                    ],
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
            else:
                raise TypeMismatch(
                    f"Unsupported type conversion: {sub.typ} to {self.context.return_type}",
                    self.stmt.value,
                )
        # Returning a byte array
        elif isinstance(sub.typ, ByteArrayLike):
            if not sub.typ.eq_base(self.context.return_type):
                raise TypeMismatch(
                    f"Trying to return base type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )
            if sub.typ.maxlen > self.context.return_type.maxlen:
                raise TypeMismatch(
                    f"Cannot cast from greater max-length {sub.typ.maxlen} to shorter "
                    f"max-length {self.context.return_type.maxlen}",
                    self.stmt.value,
                )

            # loop memory has to be allocated first.
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType('uint256'))
            # len & bytez placeholder have to be declared after each other at all times.
            len_placeholder = self.context.new_placeholder(
                typ=BaseType('uint256'))
            bytez_placeholder = self.context.new_placeholder(typ=sub.typ)

            if sub.location in ('storage', 'memory'):
                # Copy the bytes, zero-pad, write 32 into the word just before
                # them, then return ceil32(len + 64) bytes starting at that
                # word (the 32-word, the length word and the padded data).
                return LLLnode.from_list([
                    'seq',
                    make_byte_array_copier(LLLnode(
                        bytez_placeholder, location='memory', typ=sub.typ),
                                           sub,
                                           pos=getpos(self.stmt)),
                    zero_pad(bytez_placeholder),
                    ['mstore', len_placeholder, 32],
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        len_placeholder,
                        ['ceil32', ['add', ['mload', bytez_placeholder], 64]],
                        loop_memory_position=loop_memory_position,
                    )
                ],
                                         typ=None,
                                         pos=getpos(self.stmt),
                                         valency=0)
            else:
                raise Exception(f"Invalid location: {sub.location}")

        elif isinstance(sub.typ, ListType):
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType('uint256'))
            if sub.typ != self.context.return_type:
                raise TypeMismatch(
                    f"List return type {sub.typ} does not match specified "
                    f"return type, expecting {self.context.return_type}",
                    self.stmt)
            # An existing memory list can be returned directly from its
            # location.
            elif sub.location == "memory" and sub.value != "multi":
                return LLLnode.from_list(
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        sub,
                        get_size_of_type(self.context.return_type) * 32,
                        loop_memory_position=loop_memory_position,
                    ),
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
            # Literal lists and non-memory lists are first copied into a
            # fresh memory placeholder.
            # NOTE(review): unlike the sibling branches this node is built
            # without valency=0 — confirm that is intentional.
            else:
                new_sub = LLLnode.from_list(
                    self.context.new_placeholder(self.context.return_type),
                    typ=self.context.return_type,
                    location='memory',
                )
                setter = make_setter(new_sub,
                                     sub,
                                     'memory',
                                     pos=getpos(self.stmt))
                return LLLnode.from_list([
                    'seq', setter,
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        new_sub,
                        get_size_of_type(self.context.return_type) * 32,
                        loop_memory_position=loop_memory_position,
                    )
                ],
                                         typ=None,
                                         pos=getpos(self.stmt))

        # Returning a struct
        elif isinstance(sub.typ, StructType):
            retty = self.context.return_type
            if not isinstance(retty, StructType) or retty.name != sub.typ.name:
                raise TypeMismatch(
                    f"Trying to return {sub.typ}, output expecting {self.context.return_type}",
                    self.stmt.value,
                )
            return gen_tuple_return(self.stmt, self.context, sub)

        # Returning a tuple.
        elif isinstance(sub.typ, TupleType):
            if not isinstance(self.context.return_type, TupleType):
                raise TypeMismatch(
                    f"Trying to return tuple type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )

            if len(self.context.return_type.members) != len(sub.typ.members):
                raise StructureException("Tuple lengths don't match!",
                                         self.stmt)

            # check return type matches, sub type.
            # Only the node classes are compared here (e.g. BaseType vs
            # ByteArrayType), not the full types.
            for i, ret_x in enumerate(self.context.return_type.members):
                s_member = sub.typ.members[i]
                sub_type = s_member if isinstance(s_member,
                                                  NodeType) else s_member.typ
                if type(sub_type) is not type(ret_x):
                    raise StructureException(
                        "Tuple return type does not match annotated return. "
                        f"{type(sub_type)} != {type(ret_x)}", self.stmt)
            return gen_tuple_return(self.stmt, self.context, sub)

        else:
            raise TypeMismatch(f"Can't return type {sub.typ}", self.stmt)
예제 #27
0
def test_tuple_node_types():
    """Two independently built, identical TupleTypes compare equal and print as a tuple."""
    first, second = (
        TupleType([BaseType('int128'), BaseType('decimal')])
        for _ in range(2)
    )

    assert first == second
    assert "(int128, decimal)" == str(first)
예제 #28
0
    def build_in_comparator(self):
        """Compile ``left in right``, where ``right`` is a list, into LLL.

        Each element of the list is compared with the left operand inside an
        LLL ``repeat`` loop, which breaks on the first match. The left operand
        is evaluated exactly once, before the loop, for two reasons:

        * its side effects (e.g. ``msg.gas`` or a call) happen a single time;
        * any memory it touches cannot disturb the loop counter stored at
          ``MemoryPositions.FREE_LOOP_INDEX``.

        Returns:
            A ``bool``-typed LLLnode that is 1 when a match was found.

        Raises:
            TypeMismatch: if the left operand's type differs from the list's
                element type.
        """
        left = Expr(self.expr.left, self.context).lll_node
        right = Expr(self.expr.right, self.context).lll_node

        if left.typ != right.typ.subtype:
            raise TypeMismatch(
                f"{left.typ} cannot be in a list of {right.typ.subtype}",
                self.expr,
            )

        result_placeholder = self.context.new_placeholder(BaseType('bool'))
        setter = []

        # Load nth item from list in memory.
        if right.value == 'multi':
            # Copy literal to memory to be compared.
            tmp_list = LLLnode.from_list(
                obj=self.context.new_placeholder(ListType(right.typ.subtype, right.typ.count)),
                typ=ListType(right.typ.subtype, right.typ.count),
                location='memory'
            )
            setter = make_setter(tmp_list, right, 'memory', pos=getpos(self.expr))
            load_i_from_list = [
                'mload',
                ['add', tmp_list, ['mul', 32, ['mload', MemoryPositions.FREE_LOOP_INDEX]]],
            ]
        elif right.location == "storage":
            load_i_from_list = [
                'sload',
                ['add', ['sha3_32', right], ['mload', MemoryPositions.FREE_LOOP_INDEX]],
            ]
        else:
            load_i_from_list = [
                'mload',
                ['add', right, ['mul', 32, ['mload', MemoryPositions.FREE_LOOP_INDEX]]],
            ]

        # Condition repeat loop has to break on. Compares against the
        # `_in_value` stack binding instead of re-emitting the left operand's
        # code, which would re-run it on every iteration.
        break_loop_condition = [
            'if',
            ['eq', '_in_value', load_i_from_list],
            ['seq',
                ['mstore', '_result', 1],  # store true.
                'break']
        ]

        # Repeat loop to loop-compare each item in the list.
        for_loop_sequence = [
            ['mstore', result_placeholder, 0],
            # Evaluate the left operand once, then iterate.
            ['with', '_in_value', unwrap_location(left), [
                'with', '_result', result_placeholder, [
                    'repeat',
                    MemoryPositions.FREE_LOOP_INDEX,
                    0,
                    right.typ.count,
                    break_loop_condition,
                ]]],
            ['mload', result_placeholder]
        ]

        # Save list to memory, so one can iterate over it,
        # used when literal was created with tmp_list.
        if setter:
            compare_sequence = ['seq', setter] + for_loop_sequence
        else:
            compare_sequence = ['seq'] + for_loop_sequence

        # Compare the result of the repeat loop to 1, to know if a match was found.
        o = LLLnode.from_list([
            'eq', 1,
            compare_sequence],
            typ='bool',
            annotation="in comporator"
        )

        return o
예제 #29
0
 def _is_valid_contract_assign(self):
     """Check whether ``self.expr`` has exactly one argument of address type.

     Returns:
         ``(True, lll_node)`` when ``self.expr.args`` holds exactly one
         argument and its compiled LLL node has type ``address``; otherwise
         ``(False, None)``.
     """
     call_args = self.expr.args
     if not call_args or len(call_args) != 1:
         return False, None
     candidate = Expr(call_args[0], self.context).lll_node
     if candidate.typ == BaseType('address'):
         return True, candidate
     return False, None
예제 #30
0
    def parse_for_list(self):
        """Compile ``for <target> in <list>: ...`` into an LLL ``repeat`` loop.

        Supported iterables:

        * a local list variable (``Name``) located in memory or calldata;
        * a list literal (``List``), which is first copied into a memory
          placeholder;
        * a storage list (``Attribute``, e.g. ``self.xs``).

        The loop variable ``<target>`` and a hidden index variable
        ``_index_for_<target>`` are registered in the context for the
        duration of the loop and removed again afterwards.

        Raises:
            StructureException: if the list's element type is not a base type.
            CompilerPanic: if a local list lives in an unsupported location,
                or the iterable is none of the supported forms.
        """
        with self.context.range_scope():
            iter_list_node = Expr(self.stmt.iter, self.context).lll_node
        if not isinstance(iter_list_node.typ.subtype,
                          BaseType):  # Sanity check on list subtype.
            raise StructureException(
                'For loops allowed only on basetype lists.', self.stmt.iter)
        iter_var_type = (self.context.vars.get(self.stmt.iter.id).typ
                         if isinstance(self.stmt.iter, sri_ast.Name) else None)
        subtype = iter_list_node.typ.subtype.typ
        varname = self.stmt.target.id
        value_pos = self.context.new_variable(
            varname,
            BaseType(subtype),
        )
        i_pos_raw_name = '_index_for_' + varname
        i_pos = self.context.new_internal_variable(
            i_pos_raw_name,
            BaseType(subtype),
        )
        self.context.forvars[varname] = True

        # Is a list that is already allocated to memory.
        if iter_var_type:

            list_name = self.stmt.iter.id
            # make sure list cannot be altered whilst iterating.
            with self.context.in_for_loop_scope(list_name):
                iter_var = self.context.vars.get(self.stmt.iter.id)
                if iter_var.location == 'calldata':
                    fetcher = 'calldataload'
                elif iter_var.location == 'memory':
                    fetcher = 'mload'
                else:
                    # Statement compilers hold the node in `self.stmt`;
                    # referencing `self.expr` here raised AttributeError
                    # instead of this panic.
                    raise CompilerPanic(
                        f'List iteration only supported on in-memory types, '
                        f'got {iter_var.location!r} for {list_name!r}',
                    )
                # Each iteration: value <- list[i] (32-byte stride).
                body = [
                    'seq',
                    [
                        'mstore',
                        value_pos,
                        [
                            fetcher,
                            [
                                'add', iter_var.pos,
                                ['mul', ['mload', i_pos], 32]
                            ]
                        ],
                    ],
                    parse_body(self.stmt.body, self.context)
                ]
                o = LLLnode.from_list(
                    ['repeat', i_pos, 0, iter_var.size, body],
                    typ=None,
                    pos=getpos(self.stmt))

        # List gets defined in the for statement.
        elif isinstance(self.stmt.iter, sri_ast.List):
            # Allocate list to memory.
            count = iter_list_node.typ.count
            tmp_list = LLLnode.from_list(obj=self.context.new_placeholder(
                ListType(iter_list_node.typ.subtype, count)),
                                         typ=ListType(
                                             iter_list_node.typ.subtype,
                                             count),
                                         location='memory')
            setter = make_setter(tmp_list,
                                 iter_list_node,
                                 'memory',
                                 pos=getpos(self.stmt))
            body = [
                'seq',
                [
                    'mstore', value_pos,
                    [
                        'mload',
                        ['add', tmp_list, ['mul', ['mload', i_pos], 32]]
                    ]
                ],
                parse_body(self.stmt.body, self.context)
            ]
            o = LLLnode.from_list(
                ['seq', setter, ['repeat', i_pos, 0, count, body]],
                typ=None,
                pos=getpos(self.stmt))

        # List contained in storage.
        elif isinstance(self.stmt.iter, sri_ast.Attribute):
            count = iter_list_node.typ.count
            list_name = iter_list_node.annotation

            # make sure list cannot be altered whilst iterating.
            with self.context.in_for_loop_scope(list_name):
                # Storage elements occupy consecutive slots from sha3_32(list).
                body = [
                    'seq',
                    [
                        'mstore', value_pos,
                        [
                            'sload',
                            [
                                'add', ['sha3_32', iter_list_node],
                                ['mload', i_pos]
                            ]
                        ]
                    ],
                    parse_body(self.stmt.body, self.context),
                ]
                o = LLLnode.from_list(
                    ['seq', ['repeat', i_pos, 0, count, body]],
                    typ=None,
                    pos=getpos(self.stmt))

        else:
            # Any other iterable form would leave `o` unbound at the return
            # below; fail with an explicit compiler error instead.
            raise CompilerPanic(
                f'Unsupported list iteration target: {type(self.stmt.iter).__name__}',
            )

        del self.context.vars[varname]
        # this kind of open access to the vars dict should be disallowed.
        # we should use member functions to provide an API for these kinds
        # of operations.
        del self.context.vars[self.context._mangle(i_pos_raw_name)]
        del self.context.forvars[varname]
        return o