Exemplo n.º 1
0
def get_external_contract_keywords(stmt_expr, context):
    """Extract the optional ``value`` and ``gas`` keywords of an external call.

    Each recognised keyword's AST value is parsed with
    ``Expr.parse_value_expr``.

    Args:
        stmt_expr: Call node whose ``keywords`` are inspected.
        context: Passed through to ``Expr.parse_value_expr``.

    Returns:
        A ``(value, gas)`` tuple; an entry is ``None`` when the corresponding
        keyword was not supplied.

    Raises:
        TypeMismatch: If any keyword other than ``gas`` or ``value`` is given.
    """
    # Function-scope import, kept from the original (presumably to avoid a
    # circular import between the parser modules).
    from srilang.parser.expr import Expr

    parsed = {'value': None, 'gas': None}
    for keyword in stmt_expr.keywords:
        name = keyword.arg
        if name not in parsed:
            raise TypeMismatch(
                'Invalid keyword argument, only "gas" and "value" supported.',
                stmt_expr,
            )
        parsed[name] = Expr.parse_value_expr(keyword.value, context)
    return parsed['value'], parsed['gas']
Exemplo n.º 2
0
    def _check_valid_range_constant(self, arg_ast_node, raise_exception=True):
        """Parse a ``range`` argument and report whether it is an integer literal.

        The argument is parsed inside ``self.context.range_scope()``. It counts
        as an integer literal when its type is a ``BaseType`` flagged
        ``is_literal`` whose ``typ`` is ``'uint256'`` or ``'int128'``.

        Args:
            arg_ast_node: AST node of the ``range`` argument.
            raise_exception: When true, a non-literal argument raises instead
                of being reported through the return value.

        Returns:
            A ``(is_integer_literal, parsed_expression)`` pair.

        Raises:
            StructureException: If the argument is not an integer literal and
                ``raise_exception`` is true.
        """
        with self.context.range_scope():
            # TODO should catch if raise_exception == False?
            parsed = Expr.parse_value_expr(arg_ast_node, self.context)

        typ = parsed.typ
        literal_int = (
            isinstance(typ, BaseType)
            and typ.is_literal
            and typ.typ in {'uint256', 'int128'}
        )

        # Happy path (or caller opted out of the exception): report the result.
        if literal_int or not raise_exception:
            return literal_int, parsed

        raise StructureException(
            "Range only accepts literal (constant) values of type uint256 or int128",
            arg_ast_node)
Exemplo n.º 3
0
    def parse_assert(self):
        """Translate an ``assert`` statement into an LLL node.

        The test expression is parsed inside ``self.context.assertion_scope()``.
        When the statement carries a message, building the node is delegated to
        ``self._assert_reason``; otherwise a plain ``['assert', test]`` node is
        produced.

        Returns:
            The node returned by ``self._assert_reason`` or an ``assert``
            ``LLLnode`` positioned at the statement.

        Raises:
            TypeMismatch: If ``self.is_bool_expr`` rejects the test expression.
        """
        with self.context.assertion_scope():
            condition = Expr.parse_value_expr(self.stmt.test, self.context)

        if not self.is_bool_expr(condition):
            raise TypeMismatch('Only boolean expressions allowed',
                               self.stmt.test)

        reason = self.stmt.msg
        if reason:
            return self._assert_reason(condition, reason)

        return LLLnode.from_list(
            ['assert', condition],
            typ=None,
            pos=getpos(self.stmt),
        )
Exemplo n.º 4
0
    def unroll_constant(self, const, global_ctx):
        """Evaluate a constant's value and retype it to its annotated type.

        The constant's value expression is parsed in a fresh ``Context`` and
        its type is checked against the type parsed from the first annotation
        argument.

        Args:
            const: Constant definition node; ``value``, ``annotation`` and
                ``full_source_code`` are read from it.
            global_ctx: Global context used to parse the annotation type and
                to build the temporary parsing context.

        Returns:
            ``const`` itself when both types are byte arrays and the value's
            ``maxlen`` is strictly smaller than the annotation's; otherwise a
            deep copy of the parsed expression whose ``typ`` is the annotation
            type (with ``is_literal`` copied from the parsed expression).

        Raises:
            TypeMismatch: If the value's type is not acceptable for the
                annotated type.
        """
        parse_ctx = Context(
            vars=None,
            global_ctx=global_ctx,
            origcode=const.full_source_code,
            memory_allocator=MemoryAllocator(),
        )
        expr = Expr.parse_value_expr(const.value, parse_ctx)
        annotation_type = global_ctx.parse_type(const.annotation.args[0], None)

        def literal_fits(annotated_name):
            # An int128-typed value is accepted for ``annotated_name`` when
            # both types are base types and the value is within bounds.
            return (
                is_instances([expr.typ, annotation_type], BaseType)
                and [annotation_type.typ, expr.typ.typ] == [annotated_name, 'int128']
                and SizeLimits.in_bounds(annotated_name, expr.value)
            )

        if is_instances([expr.typ, annotation_type], ByteArrayType):
            if expr.typ.maxlen < annotation_type.maxlen:
                return const
            mismatch = True
        elif expr.typ != annotation_type:
            # special case for literals, which can be uint256 types as well.
            fits_uint256 = literal_fits('uint256')
            fits_int128 = literal_fits('int128')
            mismatch = not (fits_uint256 or fits_int128)
        else:
            mismatch = False

        if mismatch:
            raise TypeMismatch(
                f"Invalid value for constant type, expected {annotation_type} got "
                f"{expr.typ} instead",
                const.value,
            )

        ann_expr = copy.deepcopy(expr)
        ann_expr.typ = annotation_type
        # Annotation type doesn't have literal set, so carry it over.
        ann_expr.typ.is_literal = expr.typ.is_literal
        return ann_expr
Exemplo n.º 5
0
    def parse_if(self):
        """Translate an ``if`` statement (with optional ``else``) into an LLL node.

        The ``else`` body, when present, is parsed first inside its own block
        scope keyed by ``id(self.stmt.orelse)``. The test expression and the
        ``if`` body are then parsed inside a block scope keyed by
        ``id(self.stmt)``.

        Returns:
            An ``LLLnode`` of the form ``['if', test, body]`` with the parsed
            ``else`` body appended when there is one.

        Raises:
            TypeMismatch: If ``self.is_bool_expr`` rejects the test expression.
        """
        else_nodes = []
        if self.stmt.orelse:
            with self.context.make_blockscope(id(self.stmt.orelse)):
                else_nodes.append(parse_body(self.stmt.orelse, self.context))

        with self.context.make_blockscope(id(self.stmt)):
            condition = Expr.parse_value_expr(self.stmt.test, self.context)
            if not self.is_bool_expr(condition):
                raise TypeMismatch('Only boolean expressions allowed',
                                   self.stmt.test)

            then_node = parse_body(self.stmt.body, self.context)
            if_node = LLLnode.from_list(
                ['if', condition, then_node, *else_nodes],
                typ=None,
                pos=getpos(self.stmt),
            )
        return if_node
Exemplo n.º 6
0
def make_external_call(stmt_expr, context):
    """Build an external contract call from a call expression.

    Two shapes of ``stmt_expr.func`` are handled:

    * an attribute on a call (``func.value`` is a ``Call``): the contract name
      is the called name and the address is its first argument, parsed as a
      value expression;
    * an attribute on an attribute (``func.value`` is an ``Attribute``): the
      inner attribute names a global variable, read from storage. The contract
      name is the attribute name itself when it appears in ``context.sigs``,
      otherwise the ``name`` of the global's type.

    Args:
        stmt_expr: Call node being translated.
        context: Provides ``sigs`` and ``globals``.

    Returns:
        Whatever ``external_contract_call`` returns for the resolved contract
        name and address, with the parsed ``value``/``gas`` keywords.

    Raises:
        StructureException: If the call matches neither shape.
    """
    # Function-scope import, kept from the original (presumably to avoid a
    # circular import between the parser modules).
    from srilang.parser.expr import Expr

    value, gas = get_external_contract_keywords(stmt_expr, context)
    func = stmt_expr.func

    if isinstance(func, sri_ast.Attribute) and isinstance(func.value, sri_ast.Call):
        contract_name = func.value.func.id
        contract_address = Expr.parse_value_expr(func.value.args[0], context)
    else:
        holder = func.value
        if not isinstance(holder, sri_ast.Attribute):
            raise StructureException("Unsupported operator.", stmt_expr)

        attr = holder.attr
        if attr in context.sigs:
            contract_name = attr
        elif attr in context.globals and hasattr(context.globals[attr].typ, 'name'):
            contract_name = context.globals[attr].typ.name
        else:
            raise StructureException("Unsupported operator.", stmt_expr)

        # Both attribute-based forms load the address from the same global.
        var = context.globals[attr]
        storage_node = LLLnode.from_list(
            var.pos,
            typ=var.typ,
            location='storage',
            pos=getpos(stmt_expr),
            annotation='self.' + attr,
        )
        contract_address = unwrap_location(storage_node)

    return external_contract_call(
        stmt_expr,
        context,
        contract_name,
        contract_address,
        pos=getpos(stmt_expr),
        value=value,
        gas=gas,
    )
Exemplo n.º 7
0
def process_arg(index, arg, expected_arg_typelist, function_name, context):
    """Match ``arg`` against the accepted type descriptor(s) and return its processed form.

    The descriptors are tried in order. The first descriptor that accepts
    ``arg`` determines the return value; a descriptor that does not accept it
    is skipped and the next one is tried.

    Args:
        index: Position of the argument; used only in the error message.
        arg: AST node of the argument being checked.
        expected_arg_typelist: A single descriptor, a tuple of descriptors, or
            an ``Optional`` wrapper whose ``.typ`` holds either of those.
        function_name: Name of the called function; used only in the error
            message.
        context: Provides ``constants`` and ``parse_type`` and is handed to
            ``Expr``.

    Returns:
        Depends on the matching descriptor:

        * ``'num_literal'``: the constant's ``value`` or the literal's ``n``;
        * ``'str_literal'``: the string converted to ``bytes``, one byte per
          character;
        * ``'bytes_literal'``: the literal's ``s``;
        * ``'name_literal'``: the name's ``id``, or ``'bytes[N]'`` for a
          subscript of ``bytes``;
        * ``'*'``: ``arg`` unchanged;
        * anything else: a parsed expression node.

    Raises:
        InvalidLiteral: If a ``'str_literal'`` contains a character whose code
            point is 256 or higher.
        TypeMismatch: If no descriptor accepts ``arg``.
    """
    # NOTE(review): `Optional` must be a class exposing `.typ` (i.e. not
    # `typing.Optional`) for this isinstance check to work — confirm the import.
    if isinstance(expected_arg_typelist, Optional):
        expected_arg_typelist = expected_arg_typelist.typ
    # Normalise a single descriptor to a 1-tuple so the loop below is uniform.
    if not isinstance(expected_arg_typelist, tuple):
        expected_arg_typelist = (expected_arg_typelist, )

    # Parsed form of `arg`, reused across loop iterations once computed.
    vsub = None
    for expected_arg in expected_arg_typelist:
        if expected_arg == 'num_literal':
            # A named constant of integer base type resolves to its value.
            if context.constants.is_constant_of_base_type(arg, ('uint256', 'int128')):
                return context.constants.get_constant(arg.id, None).value
            if isinstance(arg, (sri_ast.Int, sri_ast.Decimal)):
                return arg.n
        elif expected_arg == 'str_literal':
            if isinstance(arg, sri_ast.Str):
                # Map each character to the single byte with the same ordinal.
                bytez = b''
                for c in arg.s:
                    if ord(c) >= 256:
                        raise InvalidLiteral(
                            f"Cannot insert special character {c} into byte array",
                            arg,
                        )
                    bytez += bytes([ord(c)])
                return bytez
        elif expected_arg == 'bytes_literal':
            if isinstance(arg, sri_ast.Bytes):
                return arg.s
        elif expected_arg == 'name_literal':
            if isinstance(arg, sri_ast.Name):
                return arg.id
            # NOTE(review): assumes `arg.value` has an `id` attribute; a
            # subscript of a non-name would raise AttributeError here.
            elif isinstance(arg, sri_ast.Subscript) and arg.value.id == 'bytes':
                return f'bytes[{arg.slice.value.n}]'
        elif expected_arg == '*':
            # Wildcard: the raw AST node is accepted as-is.
            return arg
        elif expected_arg == 'bytes':
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, ByteArrayType):
                return sub
        elif expected_arg == 'string':
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, StringType):
                return sub
        else:
            # Any other descriptor is treated as a type expression to parse.
            # Does not work for unit-endowed types inside compound types, e.g. timestamp[2]
            parsed_expected_type = context.parse_type(
                sri_ast.parse_to_ast(expected_arg)[0].value,
                'memory',
            )
            if isinstance(parsed_expected_type, BaseType):
                # NOTE(review): `vsub or ...` relies on the truthiness of the
                # parsed node; a falsy node would be re-parsed — confirm.
                vsub = vsub or Expr.parse_value_expr(arg, context)

                # An integer literal is also accepted for int128/uint256 when
                # its value is within the bounds of the expected type.
                is_valid_integer = (
                    (
                        expected_arg in ('int128', 'uint256') and isinstance(vsub.typ, BaseType)
                    ) and (
                        vsub.typ.typ in ('int128', 'uint256') and vsub.typ.is_literal
                    ) and (
                        SizeLimits.in_bounds(expected_arg, vsub.value)
                    )
                )

                if is_base_type(vsub.typ, expected_arg):
                    return vsub
                elif is_valid_integer:
                    return vsub
            else:
                vsub = vsub or Expr(arg, context).lll_node
                if vsub.typ == parsed_expected_type:
                    # NOTE(review): the expression is parsed again rather than
                    # returning `vsub`; `vsub` may have been produced by
                    # `parse_value_expr` in an earlier iteration — confirm
                    # whether the re-parse is intentional.
                    return Expr(arg, context).lll_node
    # No descriptor matched: report what was expected.
    if len(expected_arg_typelist) == 1:
        raise TypeMismatch(
            f"Expecting {expected_arg} for argument {index} of {function_name}",
            arg
        )
    else:
        raise TypeMismatch(
            f"Expecting one of {expected_arg_typelist} for argument {index} of {function_name}",
            arg
        )
Exemplo n.º 8
0
 def aug_assign(self):
     """Translate an augmented assignment (``+=``, ``-=``, ...) into an LLL node.

     The current value of the target is loaded through a temporary location
     alias, combined with the parsed right-hand side using the statement's
     operator, converted to the target's type and stored back.

     Returns:
         A ``with`` ``LLLnode`` performing the load/compute/store for targets
         located in ``'storage'`` or ``'memory'``; ``None`` for any other
         target location.

     Raises:
         StructureException: If the operator is not one of Add, Sub, Mult,
             Div or Mod.
         TypeMismatch: If the target's type is not a ``BaseType``.
     """
     target = self.get_target(self.stmt.target)
     sub = Expr.parse_value_expr(self.stmt.value, self.context)

     supported_ops = (
         sri_ast.Add,
         sri_ast.Sub,
         sri_ast.Mult,
         sri_ast.Div,
         sri_ast.Mod,
     )
     if not isinstance(self.stmt.op, supported_ops):
         raise StructureException("Unsupported operator for augassign",
                                  self.stmt)
     if not isinstance(target.typ, BaseType):
         raise TypeMismatch(
             "Can only use aug-assign operators with simple types!",
             self.stmt.target)

     # Pick the load/store opcodes and the alias name for the target location.
     if target.location == 'storage':
         load_op, store_op, loc_alias = 'sload', 'sstore', '_stloc'
     elif target.location == 'memory':
         load_op, store_op, loc_alias = 'mload', 'mstore', '_mloc'
     else:
         return None

     current_value = LLLnode.from_list([load_op, loc_alias],
                                       typ=target.typ,
                                       pos=target.pos)
     combined = sri_ast.BinOp(
         left=current_value,
         right=sub,
         op=self.stmt.op,
         lineno=self.stmt.lineno,
         col_offset=self.stmt.col_offset,
         end_lineno=self.stmt.end_lineno,
         end_col_offset=self.stmt.end_col_offset,
     )
     result = Expr.parse_value_expr(combined, self.context)
     converted = base_type_conversion(
         result, result.typ, target.typ, pos=getpos(self.stmt))

     return LLLnode.from_list(
         ['with', loc_alias, target, [store_op, loc_alias, converted]],
         typ=None,
         pos=getpos(self.stmt),
     )
Exemplo n.º 9
0
    def parse_for(self):
        """Translate a ``for`` statement into an LLL ``repeat`` node.

        List iteration is delegated to ``self.parse_for_list``. Otherwise the
        iterable must be a call to ``range`` with one or two arguments, in one
        of three forms:

        * ``range(n)``: start at 0, ``n`` rounds;
        * ``range(a, b)`` where ``b`` passes ``_check_valid_range_constant``:
          start at ``a``, ``b - a`` rounds;
        * ``range(x, x + n)``: start at the parsed expression ``x``, ``n``
          rounds; ``x`` must be identical in both places.

        The loop variable is registered in the context as an ``int128`` for
        the duration of body parsing and removed again afterwards.

        Returns:
            The node returned by ``parse_for_list``, or a ``repeat``
            ``LLLnode`` of the form ``['repeat', pos, start, rounds, body]``.

        Raises:
            StructureException: If the iterable is not a supported ``range``
                call, the two-argument form is malformed, or the number of
                rounds is less than 1.
        """
        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self.parse_for_list()

        # Anything that is not a call cannot be a `range(...)` iterable.
        if not isinstance(self.stmt.iter, sri_ast.Call):
            if isinstance(self.stmt.iter, sri_ast.Subscript):
                raise StructureException("Cannot iterate over a nested list",
                                         self.stmt.iter)
            raise StructureException(
                f"Cannot iterate over '{type(self.stmt.iter).__name__}' object",
                self.stmt.iter)
        # The called function must be the plain name `range`.
        if getattr(self.stmt.iter.func, 'id', None) != "range":
            raise StructureException(
                "Non-literals cannot be used as loop range",
                self.stmt.iter.func)
        if len(self.stmt.iter.args) not in {1, 2}:
            raise StructureException(
                f"Range expects between 1 and 2 arguments, got {len(self.stmt.iter.args)}",
                self.stmt.iter.func)

        block_scope_id = id(self.stmt)
        with self.context.make_blockscope(block_scope_id):
            # Get arg0
            arg0 = self.stmt.iter.args[0]
            num_of_args = len(self.stmt.iter.args)

            # Type 1 for, e.g. for i in range(10): ...
            if num_of_args == 1:
                arg0_val = self._get_range_const_value(arg0)
                start = LLLnode.from_list(0,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = arg0_val

            # Type 2 for, e.g. for i in range(100, 110): ...
            elif self._check_valid_range_constant(self.stmt.iter.args[1],
                                                  raise_exception=False)[0]:
                arg0_val = self._get_range_const_value(arg0)
                arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
                start = LLLnode.from_list(arg0_val,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = LLLnode.from_list(arg1_val - arg0_val,
                                           typ='int128',
                                           pos=getpos(self.stmt))

            # Type 3 for, e.g. for i in range(x, x + 10): ...
            else:
                arg1 = self.stmt.iter.args[1]
                # The second argument must be an addition ...
                if not isinstance(arg1, sri_ast.BinOp) or not isinstance(
                        arg1.op, sri_ast.Add):
                    raise StructureException(
                        ("Two-arg for statements must be of the form `for i "
                         "in range(start, start + rounds): ...`"),
                        arg1,
                    )

                # ... whose left operand is the same node as the first argument.
                if not sri_ast.compare_nodes(arg0, arg1.left):
                    raise StructureException(
                        ("Two-arg for statements of the form `for i in "
                         "range(x, x + y): ...` must have x identical in both "
                         f"places: {sri_ast.ast_to_dict(arg0)} {sri_ast.ast_to_dict(arg1.left)}"
                         ),
                        self.stmt.iter,
                    )

                rounds = self._get_range_const_value(arg1.right)
                start = Expr.parse_value_expr(arg0, self.context)

            # `rounds` is either a plain int or an object exposing `.value`.
            r = rounds if isinstance(rounds, int) else rounds.value
            if r < 1:
                raise StructureException(
                    f"For loop has invalid number of iterations ({r}),"
                    " the value must be greater than zero", self.stmt.iter)

            # Register the loop variable so the body can reference it.
            varname = self.stmt.target.id
            pos = self.context.new_variable(varname,
                                            BaseType('int128'),
                                            pos=getpos(self.stmt))
            self.context.forvars[varname] = True
            o = LLLnode.from_list(
                [
                    'repeat', pos, start, rounds,
                    parse_body(self.stmt.body, self.context)
                ],
                typ=None,
                pos=getpos(self.stmt),
            )
            # Remove the loop variable again once the body has been parsed.
            del self.context.vars[varname]
            del self.context.forvars[varname]

        return o