Exemple #1
0
def to_uint256(expr, args, kwargs, context):
    """Build the LLL that converts ``args[0]`` to ``uint256``.

    Supported inputs:
      * numeric literals: ``int`` as-is, ``Decimal`` truncated toward zero;
        both must fit in the uint256 range;
      * ``int128``: guarded by ``clampge`` against 0 before re-typing;
      * ``decimal``: guarded by ``clampge`` against 0, then divided by
        ``DECIMAL_DIVISOR`` (dropping the fractional part);
      * ``bool``: re-typed as-is;
      * ``bytes32`` / ``address``: the node is re-typed with no runtime check;
      * byte arrays whose max length is at most 32 bytes.

    Raises:
        InvalidLiteral: for an out-of-range or unknown numeric literal, a
            byte array longer than 32 bytes, or any other input type.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == 'num_literal':
        if isinstance(in_arg, int):
            value = in_arg
        elif isinstance(in_arg, Decimal):
            # Decimal literals are truncated toward zero, not rounded.
            value = math.trunc(in_arg)
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}", expr)
        # Pass `expr` so the error carries the literal's source position,
        # like the other errors raised by this function.
        if not SizeLimits.in_bounds('uint256', value):
            raise InvalidLiteral(f"Number out of range: {value}", expr)
        return LLLnode.from_list(value,
                                 typ=BaseType('uint256'),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == 'int128':
        return LLLnode.from_list(['clampge', in_arg, 0],
                                 typ=BaseType('uint256'),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == 'decimal':
        return LLLnode.from_list(
            ['div', ['clampge', in_arg, 0], DECIMAL_DIVISOR],
            typ=BaseType('uint256'),
            pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == 'bool':
        return LLLnode.from_list(in_arg,
                                 typ=BaseType('uint256'),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type in ('bytes32', 'address'):
        # Same value/args, only the type changes: a pure reinterpretation.
        return LLLnode(value=in_arg.value,
                       args=in_arg.args,
                       typ=BaseType('uint256'),
                       pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == 'bytes':
        if in_arg.typ.maxlen > 32:
            raise InvalidLiteral(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to uint256",
                expr,
            )
        return byte_array_to_num(in_arg, expr, 'uint256')

    else:
        raise InvalidLiteral(f"Invalid input for uint256: {in_arg}", expr)
Exemple #2
0
    def _assert_reason(self, test_expr, msg):
        """Compile ``assert <test>, <reason>`` into an ``assert_reason`` node.

        If ``msg`` is the name ``UNREACHABLE`` this defers to
        ``self._assert_unreachable``. Otherwise ``msg`` must be a non-empty
        string literal. The emitted ``seq`` writes the 4-byte selector of
        ``Error(string)`` into ``sig_placeholder``, the word ``32`` into
        ``arg_placeholder``, then the LLL for the string literal, and finally
        an ``assert_reason`` node carrying the test expression, the start
        offset of the payload and its byte length.

        Raises:
            StructureException: if ``msg`` is neither a string literal nor
                ``UNREACHABLE``, or if it is empty/whitespace-only.
        """
        if isinstance(msg, sri_ast.Name) and msg.id == 'UNREACHABLE':
            return self._assert_unreachable(test_expr, msg)

        if not isinstance(msg, sri_ast.Str):
            raise StructureException(
                'Reason parameter of assert needs to be a literal string '
                '(or UNREACHABLE constant).', msg)
        if len(msg.s.strip()) == 0:
            raise StructureException('Empty reason string not allowed.',
                                     self.stmt)
        reason_str = msg.s.strip()
        # NOTE(review): the two placeholders are allocated back-to-back before
        # the string literal is compiled; the payload below presumably relies
        # on these regions (and the literal's bytes) being contiguous in
        # memory -- confirm against new_placeholder before reordering.
        sig_placeholder = self.context.new_placeholder(BaseType(32))
        arg_placeholder = self.context.new_placeholder(BaseType(32))
        # NOTE(review): the size is taken from the *stripped* reason while the
        # bytes below come from compiling `msg` itself (unstripped); a reason
        # with leading/trailing whitespace may get a mismatched length.
        reason_str_type = ByteArrayType(len(reason_str))
        placeholder_bytes = Expr(msg, self.context).lll_node
        method_id = fourbytes_to_int(keccak256(b"Error(string)")[:4])
        assert_reason = [
            'seq',
            # mstore writes a 32-byte word, so the 4-byte selector ends up in
            # the last 4 bytes of sig_placeholder (hence the +28 below).
            ['mstore', sig_placeholder, method_id],
            ['mstore', arg_placeholder, 32],
            placeholder_bytes,
            [
                'assert_reason',
                test_expr,
                int(sig_placeholder + 28),
                # selector (4 bytes) plus the byte array's size in 32-byte words
                int(4 + get_size_of_type(reason_str_type) * 32),
            ],
        ]
        return LLLnode.from_list(assert_reason,
                                 typ=None,
                                 pos=getpos(self.stmt))
Exemple #3
0
def parse_body(code, context):
    """Compile one statement, or a list of statements, into LLL.

    A non-list ``code`` is handed straight to ``parse_stmt``. A list is
    compiled statement by statement into a ``seq`` node that ends with
    ``'pass'`` (so the sequence is zerovalent even when its last statement
    is not). The node is positioned at the first statement, or carries no
    position when the list is empty.
    """
    if isinstance(code, list):
        compiled = [parse_stmt(statement, context) for statement in code]
        position = getpos(code[0]) if code else None
        return LLLnode.from_list(['seq', *compiled, 'pass'], pos=position)
    return parse_stmt(code, context)
Exemple #4
0
def to_bool(expr, args, kwargs, context):
    """Build the LLL that converts ``args[0]`` to ``bool``.

    The value is normalised with a double ``iszero``, which maps zero to 0
    and anything else to 1. Byte arrays (max length <= 32) are first turned
    into a uint256 number with ``byte_array_to_num``.

    Raises:
        TypeMismatch: for a byte array whose max length exceeds 32.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type != 'bytes':
        operand = in_arg
    elif in_arg.typ.maxlen > 32:
        raise TypeMismatch(
            f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to bool",
            expr,
        )
    else:
        operand = byte_array_to_num(in_arg, expr, 'uint256')

    normalised = ['iszero', ['iszero', operand]]
    return LLLnode.from_list(normalised, typ=BaseType('bool'), pos=getpos(expr))
Exemple #5
0
def to_bytes32(expr, args, kwargs, context):
    """Build the LLL that converts ``args[0]`` to ``bytes32``.

    Non-byte-array inputs are re-typed in place (same value and args). For a
    byte array of max length <= 32, the first data word is loaded: from
    memory at ``in_arg + 32``, or from storage at ``sha3_32(in_arg) + 1``.

    Raises:
        TypeMismatch: if the byte array's max length exceeds 32, or if it
            lives in a location other than memory or storage (previously this
            case silently returned ``None``).
    """
    in_arg = args[0]
    input_type, _len = get_type(in_arg)

    if input_type != 'bytes':
        # Pure reinterpretation: only the type changes.
        return LLLnode(value=in_arg.value,
                       args=in_arg.args,
                       typ=BaseType('bytes32'),
                       pos=getpos(expr))

    if _len > 32:
        raise TypeMismatch(
            f"Unable to convert bytes[{_len}] to bytes32, max length is too "
            "large.",
            expr,
        )

    # NOTE(review): when the array is shorter than 32 bytes the loaded word
    # may include bytes past its length -- TODO confirm they are always zero.
    if in_arg.location == "memory":
        # Skip the 32-byte length word and load the first data word.
        return LLLnode.from_list(['mload', ['add', in_arg, 32]],
                                 typ=BaseType('bytes32'),
                                 pos=getpos(expr))
    if in_arg.location == "storage":
        return LLLnode.from_list(
            ['sload', ['add', ['sha3_32', in_arg], 1]],
            typ=BaseType('bytes32'),
            pos=getpos(expr))

    raise TypeMismatch(
        f"Unable to convert bytes[{_len}] in location {in_arg.location} to bytes32",
        expr,
    )
Exemple #6
0
    def parse_assert(self):
        """Compile an ``assert`` statement.

        The test is compiled inside the context's assertion scope and must be
        a boolean expression. With a reason message the work is delegated to
        ``_assert_reason``; otherwise a plain ``assert`` node is emitted.

        Raises:
            TypeMismatch: if the test is not a boolean expression.
        """
        with self.context.assertion_scope():
            condition = Expr.parse_value_expr(self.stmt.test, self.context)

        if not self.is_bool_expr(condition):
            raise TypeMismatch('Only boolean expressions allowed',
                               self.stmt.test)

        reason = self.stmt.msg
        if reason:
            return self._assert_reason(condition, reason)
        return LLLnode.from_list(['assert', condition],
                                 typ=None,
                                 pos=getpos(self.stmt))
Exemple #7
0
 def lll_compiler(lll, *args, **kwargs):
     """Optimize and compile raw LLL, deploy it, and return a contract handle.

     ``lll`` is anything ``LLLnode.from_list`` accepts. The only keyword read
     is ``abi`` (a falsy/missing value becomes an empty list); extra
     positional arguments are accepted and ignored. The returned contract is
     bound to the deployed address and uses ``srilangContract`` as its
     factory class.
     """
     abi = kwargs.get('abi') or []
     optimized = optimizer.optimize(LLLnode.from_list(lll))
     assembly = compile_lll.compile_to_assembly(optimized)
     bytecode, _ = compile_lll.assembly_to_evm(assembly)

     factory = w3.eth.contract(abi=abi, bytecode=bytecode)
     tx_hash = factory.constructor().transact()
     receipt = w3.eth.getTransactionReceipt(tx_hash)
     return w3.eth.contract(
         receipt['contractAddress'],
         abi=abi,
         bytecode=bytecode,
         ContractFactoryClass=srilangContract,
     )
Exemple #8
0
    def ann_assign(self):
        """Compile an annotated assignment ``name: <type> = <value>``.

        The annotation is parsed as a memory type and a new memory variable
        is registered for ``name`` *before* the right-hand side is compiled.
        A literal ``bytes[32]`` value assigned to a ``bytes32`` variable is
        rewritten as a ``bytes32`` constant. The returned node stores the
        value into the new variable via ``make_setter``.

        Raises:
            TypeMismatch: if the target is an attribute (``a.b: T = ...``).
            StructureException: if no initial value is given.
        """
        with self.context.assignment_scope():
            declared_type = parse_type(
                self.stmt.annotation,
                location='memory',
                custom_structs=self.context.structs,
                constants=self.context.constants,
            )
            target = self.stmt.target
            if isinstance(target, sri_ast.Attribute):
                raise TypeMismatch(
                    f'May not set type for field {target.attr}',
                    self.stmt,
                )
            mem_offset = self.context.new_variable(target.id, declared_type)
            if self.stmt.value is None:
                raise StructureException(
                    'New variables must be initialized explicitly', self.stmt)

            rhs = Expr(self.stmt.value, self.context).lll_node

            # A bytes[32] literal assigned to bytes32 becomes a bytes32 constant.
            rhs_is_bytes32_literal = (
                isinstance(rhs.typ, ByteArrayType)
                and rhs.typ.maxlen == 32
                and isinstance(declared_type, BaseType)
                and declared_type.typ == 'bytes32'
                and rhs.typ.is_literal
            )
            if rhs_is_bytes32_literal:
                rhs = LLLnode(
                    bytes_to_int(self.stmt.value.s),
                    typ=BaseType('bytes32'),
                    pos=getpos(self.stmt),
                )

            self._check_valid_assign(rhs)
            self._check_same_variable_assign(rhs)
            destination = LLLnode.from_list(
                mem_offset,
                typ=declared_type,
                location='memory',
                pos=getpos(self.stmt),
            )
            return make_setter(destination, rhs, 'memory', pos=getpos(self.stmt))
Exemple #9
0
    def parse_if(self):
        """Compile an ``if`` statement (with optional ``else``) into LLL.

        The ``else`` body, when present, is compiled first in its own block
        scope (keyed by ``id(self.stmt.orelse)``). The test and the ``then``
        body are then compiled in a block scope keyed by ``id(self.stmt)``.

        Raises:
            TypeMismatch: if the test is not a boolean expression.
        """
        else_branch = []
        if self.stmt.orelse:
            with self.context.make_blockscope(id(self.stmt.orelse)):
                else_branch.append(parse_body(self.stmt.orelse, self.context))

        with self.context.make_blockscope(id(self.stmt)):
            condition = Expr.parse_value_expr(self.stmt.test, self.context)
            if not self.is_bool_expr(condition):
                raise TypeMismatch('Only boolean expressions allowed',
                                   self.stmt.test)
            then_branch = parse_body(self.stmt.body, self.context)
            return LLLnode.from_list(['if', condition, then_branch, *else_branch],
                                     typ=None,
                                     pos=getpos(self.stmt))
Exemple #10
0
def compile_to_lll(input_file, output_formats, show_gas_estimates=False):
    """Compile an LLL s-expression file into the requested output formats.

    Only the first s-expression in the file is compiled. Recognised entries
    of ``output_formats`` are ``'ir'``, ``'opt_ir'``, ``'asm'`` and
    ``'bytecode'`` (a ``0x``-prefixed hex string); each present entry adds
    the matching key to the returned dict. Note that the assembly and
    bytecode are produced from the *unoptimized* IR.

    When ``show_gas_estimates`` is true, the class-level flag
    ``LLLnode.repr_show_gas`` is set (and never reset here).
    """
    with open(input_file) as source:
        parsed = parse_s_exp(source.read())

    if show_gas_estimates:
        LLLnode.repr_show_gas = True

    results = {}
    root = LLLnode.from_list(parsed[0])
    if 'ir' in output_formats:
        results['ir'] = root
    if 'opt_ir' in output_formats:
        results['opt_ir'] = optimizer.optimize(root)

    assembly = compile_lll.compile_to_assembly(root)
    if 'asm' in output_formats:
        results['asm'] = assembly
    if 'bytecode' in output_formats:
        evm_code, _ = compile_lll.assembly_to_evm(assembly)
        results['bytecode'] = '0x' + evm_code.hex()

    return results
Exemple #11
0
def optimize(node: LLLnode) -> LLLnode:
    """Return a peephole-optimized copy of the LLL tree rooted at ``node``.

    Children are optimized first (bottom-up); then at most one rewrite is
    applied to ``node`` itself:

    * a binary op listed in ``arith`` whose two operands are literal ints is
      constant-folded, combining the operands' annotations;
    * a constant added to an ``add`` whose first operand is a constant (as
      detected by ``_is_constant_add``) has the two constants merged;
    * ``add`` with a literal ``0`` operand is replaced by the other operand;
    * ``clamp`` / ``clamp_nonzero`` with literal operands are checked at
      compile time and simplified;
    * ``[eq, x, 0]`` becomes ``[iszero, x]``;
    * a ``ne`` in a condition position (``has_cond_arg``) becomes ``xor``;
    * nested ``seq`` nodes are flattened into their parent.

    Any other node is rebuilt with its optimized children; if it carries a
    ``total_gas`` figure, that figure is adjusted by the change in ``gas``.

    Raises:
        Exception: "Clamp always fails" when literal clamp operands make the
            clamp fail unconditionally.
    """
    argz = [optimize(arg) for arg in node.args]
    if node.value in arith and int_at(argz, 0) and int_at(argz, 1):
        left, right = get_int_at(argz, 0), get_int_at(argz, 1)
        # `node.value in arith` implies that `node.value` is a `str`
        calcer, symb = arith[str(node.value)]
        new_value = calcer(left, right)
        if argz[0].annotation and argz[1].annotation:
            annotation = argz[0].annotation + symb + argz[1].annotation
        elif argz[0].annotation or argz[1].annotation:
            annotation = (
                argz[0].annotation or str(left)
            ) + symb + (
                argz[1].annotation or str(right)
            )
        else:
            annotation = ''
        return LLLnode(
            new_value,
            [],
            node.typ,
            None,
            node.pos,
            annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    elif _is_constant_add(node, argz):
        # The code below relies on `_is_constant_add` guaranteeing that
        # `node.value` is a key of `arith` (hence a `str`), that `argz[0]` is a
        # literal int, and that `argz[1]` is an `add` whose first operand is a
        # literal int (TODO confirm against `_is_constant_add`).
        _, symb = arith[str(node.value)]
        if argz[0].annotation and argz[1].args[0].annotation:
            annotation = argz[0].annotation + symb + argz[1].args[0].annotation
        elif argz[0].annotation or argz[1].args[0].annotation:
            annotation = (
                argz[0].annotation or str(argz[0].value)
            ) + symb + (
                argz[1].args[0].annotation or str(argz[1].args[0].value)
            )
        else:
            annotation = ''
        return LLLnode(
            "add",
            [
                LLLnode(int(argz[0].value) + int(argz[1].args[0].value), annotation=annotation),
                argz[1].args[1],
            ],
            node.typ,
            None,
            # Keep the source position, like every other rewrite here does;
            # previously it was dropped for this rewrite only.
            node.pos,
            annotation=node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    elif node.value == "add" and get_int_at(argz, 0) == 0:
        # 0 + x -> x
        return LLLnode(
            argz[1].value,
            argz[1].args,
            node.typ,
            node.location,
            node.pos,
            annotation=argz[1].annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    elif node.value == "add" and get_int_at(argz, 1) == 0:
        # x + 0 -> x
        return LLLnode(
            argz[0].value,
            argz[0].args,
            node.typ,
            node.location,
            node.pos,
            annotation=argz[0].annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    elif node.value == "clamp" and int_at(argz, 0) and int_at(argz, 1) and int_at(argz, 2):
        # Fully literal [clamp, lo, x, hi]: decide at compile time.
        if get_int_at(argz, 0, True) > get_int_at(argz, 1, True):  # type: ignore
            raise Exception("Clamp always fails")
        elif get_int_at(argz, 1, True) > get_int_at(argz, 2, True):  # type: ignore
            raise Exception("Clamp always fails")
        else:
            return argz[1]
    elif node.value == "clamp" and int_at(argz, 0) and int_at(argz, 1):
        # Literal lower bound and value: only the upper bound remains to check.
        if get_int_at(argz, 0, True) > get_int_at(argz, 1, True):  # type: ignore
            raise Exception("Clamp always fails")
        else:
            return LLLnode(
                "clample",
                [argz[1], argz[2]],
                node.typ,
                node.location,
                node.pos,
                node.annotation,
                add_gas_estimate=node.add_gas_estimate,
                valency=node.valency,
            )
    elif node.value == "clamp_nonzero" and int_at(argz, 0):
        if get_int_at(argz, 0) != 0:
            return LLLnode(
                argz[0].value,
                [],
                node.typ,
                node.location,
                node.pos,
                node.annotation,
                add_gas_estimate=node.add_gas_estimate,
                valency=node.valency,
            )
        else:
            raise Exception("Clamp always fails")
    # [eq, x, 0] is the same as [iszero, x].
    elif node.value == 'eq' and int_at(argz, 1) and argz[1].value == 0:
        return LLLnode(
            'iszero',
            [argz[0]],
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # [ne, x, y] has the same truthyness as [xor, x, y]
    # rewrite 'ne' as 'xor' in places where truthy is accepted.
    elif has_cond_arg(node) and argz[0].value == 'ne':
        argz[0] = LLLnode.from_list(['xor'] + argz[0].args)  # type: ignore
        return LLLnode.from_list(
            [node.value] + argz,  # type: ignore
            typ=node.typ,
            location=node.location,
            pos=node.pos,
            annotation=node.annotation,
            # let from_list handle valency and gas_estimate
        )
    elif node.value == "seq":
        # Splice the children of nested `seq` nodes into this one.
        xs: List[Any] = []
        for arg in argz:
            if arg.value == "seq":
                xs.extend(arg.args)
            else:
                xs.append(arg)
        return LLLnode(
            node.value,
            xs,
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    elif node.total_gas is not None:
        o = LLLnode(
            node.value,
            argz,
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
        # Replace this node's own gas with the rebuilt node's gas in the total.
        o.total_gas = node.total_gas - node.gas + o.gas
        o.func_name = node.func_name
        return o
    else:
        return LLLnode(
            node.value,
            argz,
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
Exemple #12
0
def to_decimal(expr, args, kwargs, context):
    """Build the LLL that converts ``args[0]`` to ``decimal``.

    A decimal is represented as its value multiplied by ``DECIMAL_DIVISOR``,
    so every branch scales an integer input by that factor:

      * byte arrays (max length <= 32) go through ``byte_array_to_num`` as
        int128 first;
      * ``uint256`` / ``bytes32`` literals are range-checked at compile time;
      * non-literal ``uint256`` is clamped (unsigned) and non-literal
        ``bytes32`` is clamped (signed) to the decimal range at runtime;
      * ``address`` keeps its low 128 bits, sign-extended;
      * ``int128`` and ``bool`` are scaled directly.

    Raises:
        TypeMismatch: for a byte array whose max length exceeds 32.
        InvalidLiteral: for an out-of-range literal or an unsupported input.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == 'bytes':
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to decimal",
                expr,
            )
        num = byte_array_to_num(in_arg, expr, 'int128')
        return LLLnode.from_list(['mul', num, DECIMAL_DIVISOR],
                                 typ=BaseType('decimal'),
                                 pos=getpos(expr))

    if input_type in ('uint256', 'bytes32') and in_arg.typ.is_literal:
        # NOTE(review): this bounds value * DECIMAL_DIVISOR by the int128
        # range, which is stricter than the runtime clamps below (bounded by
        # MAXDECIMAL) -- TODO confirm the intended literal range.
        if not SizeLimits.in_bounds('int128',
                                    (in_arg.value * DECIMAL_DIVISOR)):
            raise InvalidLiteral(
                f"Number out of range: {in_arg.value}",
                expr,
            )
        return LLLnode.from_list(['mul', in_arg, DECIMAL_DIVISOR],
                                 typ=BaseType('decimal'),
                                 pos=getpos(expr))

    if input_type == 'uint256':
        # Clamp the *unscaled* input before multiplying. The previous
        # clamp-after-mul let large inputs wrap modulo 2**256 and slip under
        # the bound. For x >= 0: x * D <= MAX  <=>  x <= MAX // D.
        max_unscaled = ['div', ['mload', MemoryPositions.MAXDECIMAL], DECIMAL_DIVISOR]
        return LLLnode.from_list(
            ['mul', ['uclample', in_arg, max_unscaled], DECIMAL_DIVISOR],
            typ=BaseType('decimal'),
            pos=getpos(expr))

    if input_type == 'bytes32':
        # Signed version of the same fix. `sdiv` truncates toward zero, which
        # gives ceil(MIN / D) for the negative bound and floor(MAX / D) for
        # the positive one -- exactly the set of x with MIN <= x * D <= MAX.
        min_unscaled = ['sdiv', ['mload', MemoryPositions.MINDECIMAL], DECIMAL_DIVISOR]
        max_unscaled = ['sdiv', ['mload', MemoryPositions.MAXDECIMAL], DECIMAL_DIVISOR]
        return LLLnode.from_list(
            ['mul', ['clamp', min_unscaled, in_arg, max_unscaled], DECIMAL_DIVISOR],
            typ=BaseType('decimal'),
            pos=getpos(expr))

    if input_type == 'address':
        return LLLnode.from_list([
            'mul',
            [
                'signextend',
                15,
                ['and', in_arg, (SizeLimits.ADDRSIZE - 1)],
            ], DECIMAL_DIVISOR
        ],
                                 typ=BaseType('decimal'),
                                 pos=getpos(expr))

    if input_type in ('int128', 'bool'):
        return LLLnode.from_list(['mul', in_arg, DECIMAL_DIVISOR],
                                 typ=BaseType('decimal'),
                                 pos=getpos(expr))

    raise InvalidLiteral(f"Invalid input for decimal: {in_arg}", expr)
Exemple #13
0
    def parse_return(self):
        """Compile a ``return`` statement into LLL.

        Dispatches on the type of the returned expression: no value (when
        the function has no return type), base types, byte arrays, lists,
        structs and tuples. Base types, byte arrays and lists are emitted
        through ``make_return_stmt``; structs and tuples through
        ``gen_tuple_return``.

        Raises:
            TypeMismatch: when a value is returned from a function without a
                return type, a value is missing, or the returned type does
                not match ``self.context.return_type``.
            InvalidLiteral: when a literal is out of range for the return type.
            StructureException: on tuple length or member-kind mismatches.
            Exception: for a byte array in a location other than storage or
                memory.
        """
        if self.context.return_type is None:
            if self.stmt.value:
                raise TypeMismatch("Not expecting to return a value",
                                   self.stmt)
            # Bare `return`: zero-length return data.
            return LLLnode.from_list(
                make_return_stmt(self.stmt, self.context, 0, 0),
                typ=None,
                pos=getpos(self.stmt),
                valency=0,
            )
        if not self.stmt.value:
            raise TypeMismatch("Expecting to return a value", self.stmt)

        sub = Expr(self.stmt.value, self.context).lll_node

        # Returning a value (most common case)
        if isinstance(sub.typ, BaseType):
            sub = unwrap_location(sub)

            if self.context.return_type != sub.typ and not sub.typ.is_literal:
                raise TypeMismatch(
                    f"Trying to return base type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )
            # NOTE(review): `self.context.return_type.typ == sub.typ` compares
            # the return type's `.typ` (used as a string just below) with the
            # BaseType object `sub.typ`, not `sub.typ.typ`; confirm whether
            # that operand can ever be true.
            elif sub.typ.is_literal and (
                    self.context.return_type.typ == sub.typ
                    or 'int' in self.context.return_type.typ
                    and 'int' in sub.typ.typ):  # noqa: E501
                if not SizeLimits.in_bounds(self.context.return_type.typ,
                                            sub.value):
                    raise InvalidLiteral(
                        "Number out of range: " + str(sub.value), self.stmt)
                else:
                    # Store the single word at memory offset 0, return 32 bytes.
                    return LLLnode.from_list(
                        [
                            'seq', ['mstore', 0, sub],
                            make_return_stmt(self.stmt, self.context, 0, 32)
                        ],
                        typ=None,
                        pos=getpos(self.stmt),
                        valency=0,
                    )
            # Same base type, or an int128 value returned as int256.
            elif is_base_type(sub.typ, self.context.return_type.typ) or (
                    is_base_type(sub.typ, 'int128') and is_base_type(
                        self.context.return_type, 'int256')):  # noqa: E501
                return LLLnode.from_list(
                    [
                        'seq', ['mstore', 0, sub],
                        make_return_stmt(self.stmt, self.context, 0, 32)
                    ],
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
            else:
                raise TypeMismatch(
                    f"Unsupported type conversion: {sub.typ} to {self.context.return_type}",
                    self.stmt.value,
                )
        # Returning a byte array
        elif isinstance(sub.typ, ByteArrayLike):
            if not sub.typ.eq_base(self.context.return_type):
                raise TypeMismatch(
                    f"Trying to return base type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )
            if sub.typ.maxlen > self.context.return_type.maxlen:
                raise TypeMismatch(
                    f"Cannot cast from greater max-length {sub.typ.maxlen} to shorter "
                    f"max-length {self.context.return_type.maxlen}",
                    self.stmt.value,
                )

            # loop memory has to be allocated first.
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType('uint256'))
            # len & bytez placeholder have to be declared after each other at all times.
            len_placeholder = self.context.new_placeholder(
                typ=BaseType('uint256'))
            bytez_placeholder = self.context.new_placeholder(typ=sub.typ)

            if sub.location in ('storage', 'memory'):
                # Copy the array into bytez_placeholder, zero-pad it, write 32
                # into len_placeholder (the word just before it), and return
                # from len_placeholder: 32 + length word + data, rounded up to
                # a multiple of 32 bytes.
                return LLLnode.from_list([
                    'seq',
                    make_byte_array_copier(LLLnode(
                        bytez_placeholder, location='memory', typ=sub.typ),
                                           sub,
                                           pos=getpos(self.stmt)),
                    zero_pad(bytez_placeholder),
                    ['mstore', len_placeholder, 32],
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        len_placeholder,
                        ['ceil32', ['add', ['mload', bytez_placeholder], 64]],
                        loop_memory_position=loop_memory_position,
                    )
                ],
                                         typ=None,
                                         pos=getpos(self.stmt),
                                         valency=0)
            else:
                raise Exception(f"Invalid location: {sub.location}")

        elif isinstance(sub.typ, ListType):
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType('uint256'))
            if sub.typ != self.context.return_type:
                raise TypeMismatch(
                    f"List return type {sub.typ} does not match specified "
                    f"return type, expecting {self.context.return_type}",
                    self.stmt)
            # A list already in memory (and not a literal `multi` node) is
            # returned in place.
            elif sub.location == "memory" and sub.value != "multi":
                return LLLnode.from_list(
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        sub,
                        get_size_of_type(self.context.return_type) * 32,
                        loop_memory_position=loop_memory_position,
                    ),
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
            else:
                # Otherwise copy it into a fresh memory placeholder first.
                # NOTE(review): unlike the other return paths this node is
                # built without `valency=0` -- confirm that is intended.
                new_sub = LLLnode.from_list(
                    self.context.new_placeholder(self.context.return_type),
                    typ=self.context.return_type,
                    location='memory',
                )
                setter = make_setter(new_sub,
                                     sub,
                                     'memory',
                                     pos=getpos(self.stmt))
                return LLLnode.from_list([
                    'seq', setter,
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        new_sub,
                        get_size_of_type(self.context.return_type) * 32,
                        loop_memory_position=loop_memory_position,
                    )
                ],
                                         typ=None,
                                         pos=getpos(self.stmt))

        # Returning a struct
        elif isinstance(sub.typ, StructType):
            retty = self.context.return_type
            if not isinstance(retty, StructType) or retty.name != sub.typ.name:
                raise TypeMismatch(
                    f"Trying to return {sub.typ}, output expecting {self.context.return_type}",
                    self.stmt.value,
                )
            return gen_tuple_return(self.stmt, self.context, sub)

        # Returning a tuple.
        elif isinstance(sub.typ, TupleType):
            if not isinstance(self.context.return_type, TupleType):
                raise TypeMismatch(
                    f"Trying to return tuple type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )

            if len(self.context.return_type.members) != len(sub.typ.members):
                raise StructureException("Tuple lengths don't match!",
                                         self.stmt)

            # check return type matches, sub type.
            # Only the Python class of each member type is compared here.
            for i, ret_x in enumerate(self.context.return_type.members):
                s_member = sub.typ.members[i]
                sub_type = s_member if isinstance(s_member,
                                                  NodeType) else s_member.typ
                if type(sub_type) is not type(ret_x):
                    raise StructureException(
                        "Tuple return type does not match annotated return. "
                        f"{type(sub_type)} != {type(ret_x)}", self.stmt)
            return gen_tuple_return(self.stmt, self.context, sub)

        else:
            raise TypeMismatch(f"Can't return type {sub.typ}", self.stmt)
Exemple #14
0
 def parse_break(self):
     """Compile a ``break`` statement into an untyped LLL ``break`` node."""
     return LLLnode.from_list('break', pos=getpos(self.stmt), typ=None)
Exemple #15
0
 def parse_continue(self):
     """Compile a ``continue`` statement into an untyped LLL ``continue`` node."""
     return LLLnode.from_list('continue', pos=getpos(self.stmt), typ=None)
Exemple #16
0
 def parse_pass(self):
     """Compile a ``pass`` statement into an untyped LLL ``pass`` node."""
     return LLLnode.from_list('pass', pos=getpos(self.stmt), typ=None)
Exemple #17
0
 def aug_assign(self):
     """Compile an augmented assignment such as ``x += y``.

     Only ``+ - * / %`` on base-type targets in storage or memory are
     supported. The target's address is bound once with ``with``, the
     current value is loaded from it and combined with the right-hand side
     by compiling a synthetic ``BinOp``, and the result, converted back to
     the target's type, is stored at the same address.

     Raises:
         StructureException: for an unsupported operator, or a target that
             is neither in storage nor in memory (previously this case
             silently returned ``None``).
         TypeMismatch: if the target is not a base type.
     """
     target = self.get_target(self.stmt.target)
     sub = Expr.parse_value_expr(self.stmt.value, self.context)
     if not isinstance(self.stmt.op,
                       (sri_ast.Add, sri_ast.Sub, sri_ast.Mult, sri_ast.Div,
                        sri_ast.Mod)):
         raise StructureException("Unsupported operator for augassign",
                                  self.stmt)
     if not isinstance(target.typ, BaseType):
         raise TypeMismatch(
             "Can only use aug-assign operators with simple types!",
             self.stmt.target)

     # location -> (load opcode, store opcode, `with`-bound address variable)
     accessors = {
         'storage': ('sload', 'sstore', '_stloc'),
         'memory': ('mload', 'mstore', '_mloc'),
     }
     if target.location not in accessors:
         raise StructureException(
             f"Unsupported location for augassign: {target.location}",
             self.stmt)
     load_op, store_op, loc_var = accessors[target.location]

     # Reuse the BinOp compiler: the left operand reads the bound address.
     o = Expr.parse_value_expr(
         sri_ast.BinOp(
             left=LLLnode.from_list([load_op, loc_var],
                                    typ=target.typ,
                                    pos=target.pos),
             right=sub,
             op=self.stmt.op,
             lineno=self.stmt.lineno,
             col_offset=self.stmt.col_offset,
             end_lineno=self.stmt.end_lineno,
             end_col_offset=self.stmt.end_col_offset,
         ),
         self.context,
     )
     return LLLnode.from_list([
         'with',
         loc_var,
         target,
         [
             store_op,
             loc_var,
             base_type_conversion(
                 o, o.typ, target.typ, pos=getpos(self.stmt)),
         ],
     ],
                              typ=None,
                              pos=getpos(self.stmt))
Exemple #18
0
    def parse_for_list(self):
        """Compile a ``for`` loop whose iterable is a list.

        Three kinds of iterable are handled:

        * a ``Name`` referring to a list variable located in memory or
          calldata (elements read with ``mload`` / ``calldataload``),
        * a list literal written in the ``for`` statement itself, which is
          first copied into a newly allocated memory placeholder, and
        * an ``Attribute`` (a list held in storage), whose elements are read
          with ``sload`` from ``sha3_32(<list>) + i``.

        In every case the current element is stored into the loop variable
        (``self.stmt.target.id``) before the body runs. The loop itself is an
        LLL ``repeat`` driven by an internal index variable. The loop
        variable, the index variable and the ``forvars`` entry are removed
        from the context afterwards.

        Returns:
            LLLnode: the compiled loop (``typ=None``).

        Raises:
            StructureException: if the list's element type is not a
                ``BaseType``.
            CompilerPanic: if a named list lives neither in memory nor in
                calldata, or if the iterable is none of the supported kinds.
        """
        with self.context.range_scope():
            iter_list_node = Expr(self.stmt.iter, self.context).lll_node
        if not isinstance(iter_list_node.typ.subtype,
                          BaseType):  # Sanity check on list subtype.
            raise StructureException(
                'For loops allowed only on basetype lists.', self.stmt.iter)
        iter_var_type = (self.context.vars.get(self.stmt.iter.id).typ
                         if isinstance(self.stmt.iter, sri_ast.Name) else None)
        subtype = iter_list_node.typ.subtype.typ
        varname = self.stmt.target.id
        # Slot holding the current element, visible to the body as `varname`.
        value_pos = self.context.new_variable(
            varname,
            BaseType(subtype),
        )
        # Counter slot driven by `repeat`.
        # NOTE(review): typed with the element subtype rather than an integer
        # type; it is only used as a raw counter via `mload` -- confirm intent.
        i_pos_raw_name = '_index_for_' + varname
        i_pos = self.context.new_internal_variable(
            i_pos_raw_name,
            BaseType(subtype),
        )
        self.context.forvars[varname] = True

        # Is a list that is already allocated to memory.
        if iter_var_type:

            list_name = self.stmt.iter.id
            # make sure list cannot be altered whilst iterating.
            with self.context.in_for_loop_scope(list_name):
                iter_var = self.context.vars.get(self.stmt.iter.id)
                if iter_var.location == 'calldata':
                    fetcher = 'calldataload'
                elif iter_var.location == 'memory':
                    fetcher = 'mload'
                else:
                    # This class stores its node as `self.stmt`; referencing a
                    # non-existent `self.expr` here raised AttributeError
                    # instead of this panic.
                    raise CompilerPanic(
                        f'List iteration only supported on in-memory types: '
                        f"'{list_name}' is in {iter_var.location}",
                    )
                body = [
                    'seq',
                    [
                        'mstore',
                        value_pos,
                        [
                            fetcher,
                            [
                                'add', iter_var.pos,
                                ['mul', ['mload', i_pos], 32]
                            ]
                        ],
                    ],
                    parse_body(self.stmt.body, self.context)
                ]
                o = LLLnode.from_list(
                    ['repeat', i_pos, 0, iter_var.size, body],
                    typ=None,
                    pos=getpos(self.stmt))

        # List gets defined in the for statement.
        elif isinstance(self.stmt.iter, sri_ast.List):
            # Allocate list to memory.
            count = iter_list_node.typ.count
            tmp_list = LLLnode.from_list(obj=self.context.new_placeholder(
                ListType(iter_list_node.typ.subtype, count)),
                                         typ=ListType(
                                             iter_list_node.typ.subtype,
                                             count),
                                         location='memory')
            setter = make_setter(tmp_list,
                                 iter_list_node,
                                 'memory',
                                 pos=getpos(self.stmt))
            body = [
                'seq',
                [
                    'mstore', value_pos,
                    [
                        'mload',
                        ['add', tmp_list, ['mul', ['mload', i_pos], 32]]
                    ]
                ],
                parse_body(self.stmt.body, self.context)
            ]
            o = LLLnode.from_list(
                ['seq', setter, ['repeat', i_pos, 0, count, body]],
                typ=None,
                pos=getpos(self.stmt))

        # List contained in storage.
        elif isinstance(self.stmt.iter, sri_ast.Attribute):
            count = iter_list_node.typ.count
            list_name = iter_list_node.annotation

            # make sure list cannot be altered whilst iterating.
            with self.context.in_for_loop_scope(list_name):
                body = [
                    'seq',
                    [
                        'mstore', value_pos,
                        [
                            'sload',
                            [
                                'add', ['sha3_32', iter_list_node],
                                ['mload', i_pos]
                            ]
                        ]
                    ],
                    parse_body(self.stmt.body, self.context),
                ]
                o = LLLnode.from_list(
                    ['seq', ['repeat', i_pos, 0, count, body]],
                    typ=None,
                    pos=getpos(self.stmt))

        else:
            # Any other iterable kind would otherwise leave `o` unbound below.
            raise CompilerPanic(
                f'Unsupported list iterable: {type(self.stmt.iter).__name__}')

        del self.context.vars[varname]
        # this kind of open access to the vars dict should be disallowed.
        # we should use member functions to provide an API for these kinds
        # of operations.
        del self.context.vars[self.context._mangle(i_pos_raw_name)]
        del self.context.forvars[varname]
        return o
Exemple #19
0
def to_int128(expr, args, kwargs, context):
    """Build the LLL that converts ``args[0]`` to ``int128``.

    Only ``args[0]`` is used; ``kwargs`` and ``context`` are unused here.
    ``expr`` provides the source position for the result node and for
    error reporting.

    Supported inputs:

    * numeric literals -- range-checked at compile time; ``Decimal``
      literals are first truncated toward zero (``math.trunc``);
    * ``bytes32`` -- literals are range-checked; other values are clamped at
      runtime to ``[MINNUM, MAXNUM]`` (both loaded from memory);
    * ``address`` -- masked with ``ADDRSIZE - 1`` and then sign-extended
      from byte 15 (``signextend 15``, i.e. from bit 127);
    * ``string`` / ``bytes`` with ``maxlen <= 32`` -- via
      ``byte_array_to_num``;
    * ``uint256`` -- literals are range-checked; other values are checked at
      runtime with ``uclample`` against ``MAXNUM``;
    * ``decimal`` -- signed-divided (``sdiv``) by ``DECIMAL_DIVISOR`` and
      then clamped to ``[MINNUM, MAXNUM]``;
    * ``bool`` -- retyped as ``int128`` unchanged.

    Raises:
        InvalidLiteral: a literal is out of ``int128`` range, a numeric
            literal has an unknown Python type, or the input type is
            unsupported.
        TypeMismatch: a byte array's ``maxlen`` exceeds 32.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == 'num_literal':
        # `expr` is passed to every InvalidLiteral below so the error carries
        # a source location, matching the other branches of this function.
        if isinstance(in_arg, int):
            if not SizeLimits.in_bounds('int128', in_arg):
                raise InvalidLiteral(f"Number out of range: {in_arg}", expr)
            return LLLnode.from_list(in_arg,
                                     typ=BaseType('int128'),
                                     pos=getpos(expr))
        elif isinstance(in_arg, Decimal):
            if not SizeLimits.in_bounds('int128', math.trunc(in_arg)):
                raise InvalidLiteral(
                    f"Number out of range: {math.trunc(in_arg)}", expr)
            return LLLnode.from_list(math.trunc(in_arg),
                                     typ=BaseType('int128'),
                                     pos=getpos(expr))
        else:
            raise InvalidLiteral(
                f"Unknown numeric literal type: {in_arg}", expr)

    elif input_type == 'bytes32':
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds('int128', in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}",
                                     expr)
            else:
                return LLLnode.from_list(in_arg,
                                         typ=BaseType('int128'),
                                         pos=getpos(expr))
        else:
            return LLLnode.from_list(
                [
                    'clamp',
                    ['mload', MemoryPositions.MINNUM],
                    in_arg,
                    ['mload', MemoryPositions.MAXNUM],
                ],
                typ=BaseType('int128'),
                pos=getpos(expr))

    elif input_type == 'address':
        return LLLnode.from_list(
            [
                'signextend',
                15,
                ['and', in_arg, (SizeLimits.ADDRSIZE - 1)],
            ],
            typ=BaseType('int128'),
            pos=getpos(expr))

    elif input_type in ('string', 'bytes'):
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to int128",
                expr,
            )
        return byte_array_to_num(in_arg, expr, 'int128')

    elif input_type == 'uint256':
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds('int128', in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}",
                                     expr)
            else:
                return LLLnode.from_list(in_arg,
                                         typ=BaseType('int128'),
                                         pos=getpos(expr))

        else:
            # Unsigned upper bound only: a uint256 can never be negative.
            return LLLnode.from_list(
                ['uclample', in_arg, ['mload', MemoryPositions.MAXNUM]],
                typ=BaseType('int128'),
                pos=getpos(expr))

    elif input_type == 'decimal':
        return LLLnode.from_list(
            [
                'clamp',
                ['mload', MemoryPositions.MINNUM],
                ['sdiv', in_arg, DECIMAL_DIVISOR],
                ['mload', MemoryPositions.MAXNUM],
            ],
            typ=BaseType('int128'),
            pos=getpos(expr))

    elif input_type == 'bool':
        return LLLnode.from_list(in_arg,
                                 typ=BaseType('int128'),
                                 pos=getpos(expr))

    else:
        raise InvalidLiteral(f"Invalid input for int128: {in_arg}", expr)
Exemple #20
0
 def _assert_unreachable(test_expr, msg):
     """Wrap ``test_expr`` in an ``assert_unreachable`` LLL node.

     ``msg`` is only used to derive the node's source position via
     ``getpos``; the resulting node has no type (``typ=None``).
     """
     source_pos = getpos(msg)
     return LLLnode.from_list(
         ['assert_unreachable', test_expr],
         typ=None,
         pos=source_pos,
     )
Exemple #21
0
    def call(self):
        """Compile a call expression that appears as a standalone statement.

        Dispatch, in order:

        * a bare name ``foo(...)``: built through ``STMT_DISPATCH_TABLE``;
          names only found in ``DISPATCH_TABLE``, or unknown names, are
          rejected;
        * ``self.<fn>(...)``: compiled by ``self_call.make_call``;
        * ``log.<Event>(...)``: the event must be declared in
          ``self.context.sigs['self']`` and receive exactly as many arguments
          as it declares; indexed arguments become log topics and the rest
          are packed as log data;
        * anything else: compiled by ``external_call.make_external_call``.

        Raises:
            StructureException: for an unknown bare-name call, or one whose
                result must be used.
            EventDeclarationException: for an undeclared event or an argument
                count mismatch.
        """
        func_node = self.stmt.func

        def _attr_on_name(owner):
            # True for calls shaped like `<owner>.<attr>(...)`.
            return (isinstance(func_node, sri_ast.Attribute)
                    and isinstance(func_node.value, sri_ast.Name)
                    and func_node.value.id == owner)

        if isinstance(func_node, sri_ast.Name):
            funcname = func_node.id
            if funcname in STMT_DISPATCH_TABLE:
                return STMT_DISPATCH_TABLE[funcname].build_LLL(
                    self.stmt, self.context)
            if funcname in DISPATCH_TABLE:
                # Builtins from DISPATCH_TABLE must have their result used.
                raise StructureException(
                    f"Function {funcname} can not be called without being used.",
                    self.stmt,
                )
            raise StructureException(
                f"Unknown function: '{self.stmt.func.id}'.",
                self.stmt,
            )

        if _attr_on_name("self"):
            return self_call.make_call(self.stmt, self.context)

        if not _attr_on_name('log'):
            return external_call.make_external_call(self.stmt, self.context)

        # --- log.<Event>(...) ---
        event_name = func_node.attr
        declared_events = self.context.sigs['self']
        if event_name not in declared_events:
            raise EventDeclarationException(
                f"Event not declared yet: {self.stmt.func.attr}")
        event = declared_events[event_name]

        if len(event.indexed_list) != len(self.stmt.args):
            raise EventDeclarationException(
                f"{event.name} received {len(self.stmt.args)} arguments but "
                f"expected {len(event.indexed_list)}")

        # Route each (declared, supplied) argument pair to topics or data.
        expected_topics, topics = [], []
        expected_data, data = [], []
        for idx, is_indexed in enumerate(event.indexed_list):
            expected_bucket, actual_bucket = (
                (expected_topics, topics) if is_indexed
                else (expected_data, data))
            expected_bucket.append(event.args[idx])
            actual_bucket.append(self.stmt.args[idx])

        src_pos = getpos(self.stmt)
        topics = pack_logging_topics(
            event.event_id,
            topics,
            expected_topics,
            self.context,
            pos=src_pos,
        )
        inargs, inargsize, inargsize_node, inarg_start = pack_logging_data(
            expected_data,
            data,
            self.context,
            pos=src_pos,
        )

        # When a size node is returned, the data size is mload-ed at runtime.
        size = inargsize if inargsize_node is None else ['mload', inargsize_node]
        log_op = LLLnode.from_list(
            ["log" + str(len(topics)), inarg_start, size] + topics,
            add_gas_estimate=inargsize * 10,
        )
        return LLLnode.from_list(['seq', inargs, log_op],
                                 typ=None,
                                 pos=src_pos)
Exemple #22
0
def test_sha3_32():
    """``sha3_32`` stores its argument at memory offset 192 and hashes those
    32 bytes; the optimizer must leave that assembly unchanged."""
    expected_asm = [
        'PUSH1', 0, 'PUSH1', 192, 'MSTORE',  # mstore(192, 0)
        'PUSH1', 32, 'PUSH1', 192, 'SHA3',   # sha3(192, 32)
    ]

    unoptimized = compile_lll.compile_to_assembly(
        LLLnode.from_list(['sha3_32', 0]))
    assert unoptimized == expected_asm

    # Build a fresh node so the optimizer never sees one already compiled.
    optimized = compile_lll.compile_to_assembly(
        optimizer.optimize(LLLnode.from_list(['sha3_32', 0])))
    assert optimized == expected_asm
Exemple #23
0
    def parse_for(self):
        """Compile a ``for`` statement.

        Iteration over a list is delegated to ``parse_for_list``. Otherwise
        the iterable must be a call to ``range`` with one or two arguments,
        in one of three forms:

        1. ``range(N)``       -- start is 0, the loop runs N times.
        2. ``range(A, B)``    -- A and B both accepted by
           ``_check_valid_range_constant``; the loop runs ``B - A`` times.
        3. ``range(x, x + N)`` -- ``x`` is any expression that must appear
           identically (``compare_nodes``) in both places; runs N times.

        N, A and B are resolved through ``_get_range_const_value``. The
        iteration count must be at least 1. The loop variable is declared as
        ``int128`` for the duration of the loop and is removed from the
        context (``vars`` and ``forvars``) afterwards.

        Returns:
            LLLnode: a ``repeat`` node wrapping the compiled body
            (``typ=None``).

        Raises:
            StructureException: for a non-``Call`` iterable, a call to
                something other than ``range``, a wrong number of ``range``
                arguments, a malformed two-argument form, or an iteration
                count below 1.
        """
        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self.parse_for_list()

        if not isinstance(self.stmt.iter, sri_ast.Call):
            # Subscript iterables get a dedicated "nested list" message.
            if isinstance(self.stmt.iter, sri_ast.Subscript):
                raise StructureException("Cannot iterate over a nested list",
                                         self.stmt.iter)
            raise StructureException(
                f"Cannot iterate over '{type(self.stmt.iter).__name__}' object",
                self.stmt.iter)
        # getattr: the called object may not be a plain Name (e.g. `a.b()`),
        # in which case it has no `id` and is rejected here.
        if getattr(self.stmt.iter.func, 'id', None) != "range":
            raise StructureException(
                "Non-literals cannot be used as loop range",
                self.stmt.iter.func)
        if len(self.stmt.iter.args) not in {1, 2}:
            raise StructureException(
                f"Range expects between 1 and 2 arguments, got {len(self.stmt.iter.args)}",
                self.stmt.iter.func)

        # The block scope is keyed on the identity of this statement node.
        block_scope_id = id(self.stmt)
        with self.context.make_blockscope(block_scope_id):
            # Get arg0
            arg0 = self.stmt.iter.args[0]
            num_of_args = len(self.stmt.iter.args)

            # Type 1 for, e.g. for i in range(10): ...
            # Here `rounds` is the plain value from _get_range_const_value.
            if num_of_args == 1:
                arg0_val = self._get_range_const_value(arg0)
                start = LLLnode.from_list(0,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = arg0_val

            # Type 2 for, e.g. for i in range(100, 110): ...
            # Chosen when the second argument is itself a valid constant;
            # here `rounds` is an LLLnode.
            elif self._check_valid_range_constant(self.stmt.iter.args[1],
                                                  raise_exception=False)[0]:
                arg0_val = self._get_range_const_value(arg0)
                arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
                start = LLLnode.from_list(arg0_val,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = LLLnode.from_list(arg1_val - arg0_val,
                                           typ='int128',
                                           pos=getpos(self.stmt))

            # Type 3 for, e.g. for i in range(x, x + 10): ...
            else:
                arg1 = self.stmt.iter.args[1]
                if not isinstance(arg1, sri_ast.BinOp) or not isinstance(
                        arg1.op, sri_ast.Add):
                    raise StructureException(
                        ("Two-arg for statements must be of the form `for i "
                         "in range(start, start + rounds): ...`"),
                        arg1,
                    )

                if not sri_ast.compare_nodes(arg0, arg1.left):
                    raise StructureException(
                        ("Two-arg for statements of the form `for i in "
                         "range(x, x + y): ...` must have x identical in both "
                         f"places: {sri_ast.ast_to_dict(arg0)} {sri_ast.ast_to_dict(arg1.left)}"
                         ),
                        self.stmt.iter,
                    )

                rounds = self._get_range_const_value(arg1.right)
                start = Expr.parse_value_expr(arg0, self.context)

            # `rounds` may be an int (types 1 and 3) or an LLLnode (type 2);
            # normalize it before checking that the loop runs at least once.
            r = rounds if isinstance(rounds, int) else rounds.value
            if r < 1:
                raise StructureException(
                    f"For loop has invalid number of iterations ({r}),"
                    " the value must be greater than zero", self.stmt.iter)

            # The loop variable's slot doubles as the `repeat` counter.
            varname = self.stmt.target.id
            pos = self.context.new_variable(varname,
                                            BaseType('int128'),
                                            pos=getpos(self.stmt))
            self.context.forvars[varname] = True
            o = LLLnode.from_list(
                [
                    'repeat', pos, start, rounds,
                    parse_body(self.stmt.body, self.context)
                ],
                typ=None,
                pos=getpos(self.stmt),
            )
            # The loop variable goes out of scope once the loop is built.
            del self.context.vars[varname]
            del self.context.forvars[varname]

        return o