예제 #1
0
    def parse_assert(self):
        """Build the LLL node for an ``assert`` statement.

        The test expression is parsed while the context's in-assertion flag
        is set and must satisfy ``self.is_bool_expr``.  Without a message the
        result is a plain ``['assert', test]`` node.  With a message the
        result is a ``seq`` that stores the method id derived from
        ``sha3(b"Error(string)")``, stores the constant 32, evaluates the
        reason literal and ends with an ``assert_reason`` node.

        Raises:
            TypeMismatchException: the test is not a boolean expression.
            StructureException: the message is not a string literal, or is
                empty once surrounding whitespace is stripped.
        """
        ctx = self.context
        stmt = self.stmt

        previous_flag = ctx.in_assertion
        try:
            ctx.set_in_assertion(True)
            test_expr = Expr.parse_value_expr(stmt.test, ctx)
        finally:
            # Put the flag back however parsing ended.
            ctx.set_in_assertion(previous_flag)

        if not self.is_bool_expr(test_expr):
            raise TypeMismatchException('Only boolean expressions allowed',
                                        stmt.test)

        msg = stmt.msg
        if not msg:
            return LLLnode.from_list(['assert', test_expr],
                                     typ=None,
                                     pos=getpos(stmt))

        if not isinstance(msg, ast.Str):
            raise StructureException(
                'Reason parameter of assert needs to be a literal string.',
                msg)
        reason = msg.s.strip()
        if not reason:
            raise StructureException('Empty reason string not allowed.', stmt)

        sig_slot = ctx.new_placeholder(BaseType(32))
        arg_slot = ctx.new_placeholder(BaseType(32))
        reason_type = ByteArrayType(len(reason))
        reason_node = Expr(msg, ctx).lll_node
        method_id = fourbytes_to_int(sha3(b"Error(string)")[:4])

        # NOTE(review): the +28 presumably skips the zero padding in front of
        # the 4 method-id bytes inside the 32-byte word - confirm.
        data_start = int(sig_slot + 28)
        data_length = int(4 + 32 + get_size_of_type(reason_type) * 32)

        return LLLnode.from_list(
            [
                'seq',
                ['mstore', sig_slot, method_id],
                ['mstore', arg_slot, 32],
                reason_node,
                ['assert_reason', test_expr, data_start, data_length],
            ],
            typ=None,
            pos=getpos(stmt))
예제 #2
0
def raw_log(expr, args, kwargs, context):
    """Build the LLL node for a ``raw_log(topics, data)`` call.

    ``args[0]`` must be a list literal of at most four elements, each of
    which parses to a ``bytes32`` value; these become the log topics.
    ``args[1]`` is the data: when it is already located in memory it is
    logged in place, otherwise it is first copied into a fresh memory
    placeholder.

    Raises:
        StructureException: the first argument is not a list of 0-4 items.
        TypeMismatchException: a topic is not of base type ``bytes32``.
    """
    topic_list = args[0]
    if not (isinstance(topic_list, ast.List) and len(topic_list.elts) <= 4):
        raise StructureException(
            "Expecting a list of 0-4 topics as first argument", topic_list)

    topics = []
    for topic_ast in topic_list.elts:
        topic = Expr.parse_value_expr(topic_ast, context)
        if not is_base_type(topic.typ, 'bytes32'):
            raise TypeMismatchException(
                "Expecting a bytes32 argument as topic", topic_ast)
        topics.append(topic)

    # The opcode name carries the topic count: log0 .. log4.
    log_op = "log" + str(len(topics))
    data = args[1]

    if data.location == "memory":
        log_in_place = [log_op, ["add", "_arr", 32], ["mload", "_arr"]]
        log_in_place += topics
        return LLLnode.from_list(["with", "_arr", data, log_in_place],
                                 typ=None,
                                 pos=getpos(expr))

    # Data lives elsewhere: copy it into a memory placeholder, then log that.
    placeholder = context.new_placeholder(data.typ)
    placeholder_node = LLLnode.from_list(placeholder,
                                         typ=data.typ,
                                         location='memory')
    source_alias = LLLnode.from_list('_sub',
                                     typ=data.typ,
                                     location=data.location)
    copier = make_byte_array_copier(placeholder_node, source_alias)
    log_copy = [
        log_op,
        ["add", placeholder_node, 32],
        ["mload", placeholder_node],
    ]
    log_copy += topics
    return LLLnode.from_list(["with", "_sub", data, ["seq", copier, log_copy]],
                             typ=None,
                             pos=getpos(expr))
예제 #3
0
    def unroll_constant(self, const):
        """Parse a constant's value and check it against its annotation.

        Returns ``const`` unchanged when both the value and the annotation
        are byte arrays and the value's ``maxlen`` is strictly smaller than
        the annotated one.  Otherwise returns the parsed expression with its
        ``typ`` replaced by the annotated type.

        Raises:
            TypeMismatchException: the parsed value's type is not accepted
                for the annotated type.
        """
        parsed = Expr.parse_value_expr(
            const.value,
            Context(vars=None, global_ctx=self, origcode=const.source_code))
        declared = parse_type(const.annotation.args[0],
                              None,
                              custom_units=self._custom_units,
                              custom_structs=self._structs)

        mismatch = False

        if self.is_instances([parsed.typ, declared], ByteArrayType):
            if parsed.typ.maxlen < declared.maxlen:
                return const
            mismatch = True

        elif parsed.typ != declared:
            mismatch = True
            # Integer literals parse as int128; accept them for a uint256 or
            # int128 annotation when the value is within that type's bounds.
            if self.is_instances([parsed.typ, declared], BaseType):
                type_pair = [declared.typ, parsed.typ.typ]
                accepted = (
                    (['uint256', 'int128'], 'uint256'),
                    (['int128', 'int128'], 'int128'),
                )
                for literal_pair, bound_type in accepted:
                    if type_pair == literal_pair and \
                       SizeLimits.in_bounds(bound_type, parsed.value):
                        mismatch = False
                        break

        if mismatch:
            raise TypeMismatchException(
                'Invalid value for constant type, expected %r' %
                declared, const.value)

        parsed.typ = declared
        return parsed
예제 #4
0
파일: stmt.py 프로젝트: vilsbole/vyper
    def parse_if(self):
        """Build the LLL ``if`` node for this statement.

        The ``else`` body (when present) is parsed first, inside its own
        block scope keyed by ``id(self.stmt.orelse)``.  The test expression
        and the main body are then parsed inside a block scope keyed by
        ``id(self.stmt)``.

        Raises:
            TypeMismatchException: the test is not a boolean expression.
        """
        from .parser import (
            parse_body, )

        else_part = []
        if self.stmt.orelse:
            with self.context.make_blockscope(id(self.stmt.orelse)):
                else_body = parse_body(self.stmt.orelse, self.context)
                else_part = [['seq', else_body]]

        with self.context.make_blockscope(id(self.stmt)):
            condition = Expr.parse_value_expr(self.stmt.test, self.context)
            if not self.is_bool_expr(condition):
                raise TypeMismatchException('Only boolean expressions allowed',
                                            self.stmt.test)

            then_body = parse_body(self.stmt.body, self.context)
            if_node = ['if', condition, ['seq', then_body]]
            if_node.extend(else_part)
            return LLLnode.from_list(if_node, typ=None, pos=getpos(self.stmt))
예제 #5
0
파일: stmt.py 프로젝트: olwee/vyper
    def parse_if(self):
        """Build the LLL ``if`` node for this statement.

        The ``else`` body (when present) is parsed first inside a block scope
        keyed by ``id(self.stmt.orelse)``; the test expression and main body
        are parsed inside a block scope keyed by ``id(self.stmt)``.

        Every opened block scope is closed again even when parsing raises,
        so a failed statement does not leave its scope id registered on the
        shared context.

        Raises:
            TypeMismatchException: the test is not a boolean expression.
        """
        from .parser import (
            parse_body,
        )
        if self.stmt.orelse:
            block_scope_id = id(self.stmt.orelse)
            self.context.start_blockscope(block_scope_id)
            try:
                add_on = [parse_body(self.stmt.orelse, self.context)]
            finally:
                # Close the else-branch scope on success and on failure.
                self.context.end_blockscope(block_scope_id)
        else:
            add_on = []

        block_scope_id = id(self.stmt)
        self.context.start_blockscope(block_scope_id)
        try:
            test_expr = Expr.parse_value_expr(self.stmt.test, self.context)

            if not self.is_bool_expr(test_expr):
                raise TypeMismatchException('Only boolean expressions allowed', self.stmt.test)

            return LLLnode.from_list(
                ['if', test_expr, parse_body(self.stmt.body, self.context)] + add_on,
                typ=None, pos=getpos(self.stmt)
            )
        finally:
            # Runs for the normal return and for any exception raised above.
            self.context.end_blockscope(block_scope_id)
예제 #6
0
파일: stmt.py 프로젝트: spencerbh/vyper
    def call(self):
        """Compile a call used as a statement into an LLL node.

        The callee shape is tested in this order:

        1. a bare name: dispatched through ``stmt_dispatch_table``;
        2. ``self.<name>(...)``: handed to ``self_call.make_call``;
        3. ``<Name>(<address>).<method>(...)``: external contract call;
        4. ``<x>.<attr>.<method>(...)`` where ``attr`` is a key of
           ``context.sigs`` or of ``context.globals``: external call on a
           contract address read from storage;
        5. ``log.<Event>(...)``: event emission.

        Raises:
            StructureException: unknown or unusable function name, or an
                unsupported callee shape.
            EventDeclarationException: undeclared event, or wrong number of
                event arguments.
        """
        from .parser import (
            pack_logging_data,
            pack_logging_topics,
            external_contract_call,
        )
        stmt = self.stmt
        func = stmt.func

        # 1. Bare function name.
        if isinstance(func, ast.Name):
            name = func.id
            if name in stmt_dispatch_table:
                return stmt_dispatch_table[name](stmt, self.context)
            if name in dispatch_table:
                raise StructureException(
                    "Function {} can not be called without being used.".format(
                        name), stmt)
            raise StructureException(
                "Unknown function: '{}'.".format(name), stmt)

        # 2. self.<method>(...)
        if isinstance(func, ast.Attribute) and isinstance(
                func.value, ast.Name) and func.value.id == "self":
            return self_call.make_call(stmt, self.context)

        # 3. <ContractName>(<address>).<method>(...)
        if isinstance(func, ast.Attribute) and isinstance(
                func.value, ast.Call):
            contract_name = func.value.func.id
            contract_address = Expr.parse_value_expr(func.value.args[0],
                                                     self.context)
            return external_contract_call(stmt,
                                          self.context,
                                          contract_name,
                                          contract_address,
                                          pos=getpos(stmt))

        # 4. Contract address held in a storage variable.  A name found in
        # context.sigs is used directly as the contract name; otherwise the
        # name comes from the global variable's type unit.
        if isinstance(func.value, ast.Attribute):
            member = func.value.attr
            in_sigs = member in self.context.sigs
            if in_sigs or member in self.context.globals:
                if in_sigs:
                    contract_name = member
                else:
                    contract_name = self.context.globals[member].typ.unit
                var = self.context.globals[member]
                storage_node = LLLnode.from_list(var.pos,
                                                 typ=var.typ,
                                                 location='storage',
                                                 pos=getpos(stmt),
                                                 annotation='self.' + member)
                contract_address = unwrap_location(storage_node)
                return external_contract_call(stmt,
                                              self.context,
                                              contract_name,
                                              contract_address,
                                              pos=getpos(stmt))

        # 5. log.<Event>(...)
        if isinstance(func, ast.Attribute) and func.value.id == 'log':
            event_name = func.attr
            if event_name not in self.context.sigs['self']:
                raise EventDeclarationException("Event not declared yet: %s" %
                                                event_name)
            event = self.context.sigs['self'][event_name]
            if len(event.indexed_list) != len(stmt.args):
                raise EventDeclarationException(
                    "%s received %s arguments but expected %s" %
                    (event.name, len(stmt.args), len(event.indexed_list)))

            # Split declared and supplied arguments into indexed (topics) and
            # non-indexed (data) groups, keeping their relative order.
            expected_topics, topic_args = [], []
            expected_data, data_args = [], []
            for idx, is_indexed in enumerate(event.indexed_list):
                if is_indexed:
                    declared_bucket, given_bucket = expected_topics, topic_args
                else:
                    declared_bucket, given_bucket = expected_data, data_args
                declared_bucket.append(event.args[idx])
                given_bucket.append(stmt.args[idx])

            packed_topics = pack_logging_topics(event.event_id,
                                                topic_args,
                                                expected_topics,
                                                self.context,
                                                pos=getpos(stmt))
            inargs, inargsize, inargsize_node, inarg_start = pack_logging_data(
                expected_data, data_args, self.context, pos=getpos(stmt))

            # Use the static size unless a size node was returned, in which
            # case the size is loaded from that node at runtime.
            sz = inargsize if inargsize_node is None else ['mload', inargsize_node]

            log_node = LLLnode.from_list(
                ["log" + str(len(packed_topics)), inarg_start, sz] + packed_topics,
                add_gas_estimate=inargsize * 10)
            return LLLnode.from_list(['seq', inargs, log_node],
                                     typ=None,
                                     pos=getpos(stmt))

        raise StructureException(
            "Unsupported operator: %r" % ast.dump(stmt), stmt)
예제 #7
0
def make_external_call(stmt_expr, context):
    """Compile an external contract call expression into LLL.

    Three callee shapes are recognised, tested in this order:

    * ``<Name>(<address>).<method>(...)``: the address is the parsed first
      argument of the inner call;
    * ``<x>.<attr>.<method>(...)`` with ``attr`` a key of ``context.sigs``:
      ``attr`` is the contract name;
    * the same shape with ``attr`` a key of ``context.globals``: the
      contract name is the variable's ``typ.unit``.

    In the last two cases the address is read from the storage variable
    ``context.globals[attr]``.  The ``value`` and ``gas`` keywords obtained
    from ``get_external_contract_keywords`` are forwarded to
    ``external_contract_call``.

    Raises:
        StructureException: the callee has none of the shapes above.
    """
    from vyper.parser.expr import Expr
    value, gas = get_external_contract_keywords(stmt_expr, context)

    func = stmt_expr.func
    if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Call):
        contract_name = func.value.func.id
        contract_address = Expr.parse_value_expr(func.value.args[0], context)
    else:
        holder = func.value
        via_member = isinstance(holder, ast.Attribute)
        if via_member and holder.attr in context.sigs:
            contract_name = holder.attr
        elif via_member and holder.attr in context.globals:
            contract_name = context.globals[holder.attr].typ.unit
        else:
            raise StructureException(
                "Unsupported operator: %r" % ast.dump(stmt_expr), stmt_expr)

        # Both member forms read the contract address from storage.
        var = context.globals[holder.attr]
        storage_node = LLLnode.from_list(var.pos,
                                         typ=var.typ,
                                         location='storage',
                                         pos=getpos(stmt_expr),
                                         annotation='self.' + holder.attr)
        contract_address = unwrap_location(storage_node)

    return external_contract_call(stmt_expr,
                                  context,
                                  contract_name,
                                  contract_address,
                                  pos=getpos(stmt_expr),
                                  value=value,
                                  gas=gas)
예제 #8
0
def process_arg(index, arg, expected_arg_typelist, function_name, context):
    """Match one call argument against the types a built-in accepts.

    ``expected_arg_typelist`` may be a single specifier, a tuple of
    specifiers, or an ``Optional`` wrapper whose ``typ`` holds either.  The
    specifiers are tried in order and the first one that accepts ``arg``
    decides the return value:

    * ``'num_literal'``: the constant's value, or ``arg.n``
    * ``'str_literal'``: the string's characters as ``bytes``
    * ``'name_literal'``: ``arg.id`` or a ``'bytes[N]'`` string
    * ``'*'``: ``arg`` itself, unparsed
    * ``'bytes'`` / ``'string'``: the parsed LLL node
    * anything else: parsed as a type and compared to the argument's type

    Raises:
        InvalidLiteralException: a ``str_literal`` contains a character
            with code point 256 or above.
        TypeMismatchException: no specifier accepted the argument.
    """
    if isinstance(expected_arg_typelist, Optional):
        expected_arg_typelist = expected_arg_typelist.typ
    if not isinstance(expected_arg_typelist, tuple):
        expected_arg_typelist = (expected_arg_typelist, )

    # Parsed form of `arg`, created on first need and reused by later
    # candidate types.
    parsed_arg = None
    for wanted in expected_arg_typelist:
        if wanted == 'num_literal':
            if context.constants.is_constant_of_base_type(
                    arg, ('uint256', 'int128')):
                return context.constants.get_constant(arg.id, None).value
            if isinstance(arg, ast.Num) and get_original_if_0_prefixed(
                    arg, context) is None:
                return arg.n
            continue

        if wanted == 'str_literal':
            if isinstance(arg, ast.Str) and get_original_if_0_prefixed(
                    arg, context) is None:
                code_points = []
                for char in arg.s:
                    if ord(char) >= 256:
                        raise InvalidLiteralException(
                            "Cannot insert special character %r into byte array"
                            % char,
                            arg,
                        )
                    code_points.append(ord(char))
                return bytes(code_points)
            continue

        if wanted == 'name_literal':
            if isinstance(arg, ast.Name):
                return arg.id
            if isinstance(arg, ast.Subscript) and arg.value.id == 'bytes':
                return 'bytes[%s]' % arg.slice.value.n
            continue

        if wanted == '*':
            return arg

        if wanted == 'bytes':
            node = Expr(arg, context).lll_node
            if isinstance(node.typ, ByteArrayType):
                return node
            continue

        if wanted == 'string':
            node = Expr(arg, context).lll_node
            if isinstance(node.typ, StringType):
                return node
            continue

        # Does not work for unit-endowed types inside compound types, e.g. timestamp[2]
        declared = context.parse_type(
            ast.parse(wanted).body[0].value,
            'memory',
        )
        if isinstance(declared, BaseType):
            parsed_arg = parsed_arg or Expr.parse_value_expr(arg, context)

            # An integer literal is also accepted for either integer type
            # when its value is within that type's bounds.
            literal_fits = (wanted in ('int128', 'uint256')
                            and isinstance(parsed_arg.typ, BaseType)
                            and parsed_arg.typ.typ in ('int128', 'uint256')
                            and parsed_arg.typ.is_literal
                            and SizeLimits.in_bounds(wanted, parsed_arg.value))

            if is_base_type(parsed_arg.typ, wanted) or literal_fits:
                return parsed_arg
        else:
            parsed_arg = parsed_arg or Expr(arg, context).lll_node
            if parsed_arg.typ == declared:
                return Expr(arg, context).lll_node

    if len(expected_arg_typelist) == 1:
        raise TypeMismatchException(
            "Expecting %s for argument %r of %s" %
            (wanted, index, function_name), arg)
    raise TypeMismatchException(
        "Expecting one of %r for argument %r of %s" %
        (expected_arg_typelist, index, function_name), arg)
예제 #9
0
파일: stmt.py 프로젝트: erdnaag/vyper
 def parse_AugAssign(self):
     """Compile an augmented assignment (``x += y`` and similar) into LLL.

     The target location is bound to an alias with ``with``; the current
     value is loaded through that alias, combined with the right-hand side
     via a synthetic ``BinOp`` carrying the statement's operator, converted
     to the target's type and stored back.

     Returns ``None`` when the target is not a ``BaseType`` or when its
     location is neither ``"storage"`` nor ``"memory"``.
     """
     target = self._get_target(self.stmt.target)
     sub = Expr.parse_value_expr(self.stmt.value, self.context)
     if not isinstance(target.typ, BaseType):
         return

     # location -> (load opcode, store opcode, alias used inside `with`)
     opcodes_by_location = {
         "storage": ("sload", "sstore", "_stloc"),
         "memory": ("mload", "mstore", "_mloc"),
     }
     if target.location not in opcodes_by_location:
         return
     load_op, store_op, alias = opcodes_by_location[target.location]

     current_value = LLLnode.from_list([load_op, alias],
                                       typ=target.typ,
                                       pos=target.pos)
     combined = vy_ast.BinOp(
         left=current_value,
         right=sub,
         op=self.stmt.op,
         lineno=self.stmt.lineno,
         col_offset=self.stmt.col_offset,
         end_lineno=self.stmt.end_lineno,
         end_col_offset=self.stmt.end_col_offset,
     )
     lll_node = Expr.parse_value_expr(combined, self.context)
     new_value = base_type_conversion(lll_node,
                                      lll_node.typ,
                                      target.typ,
                                      pos=getpos(self.stmt))
     return LLLnode.from_list(
         ["with", alias, target, [store_op, alias, new_value]],
         typ=None,
         pos=getpos(self.stmt),
     )
예제 #10
0
파일: stmt.py 프로젝트: agroce/vyper
 def aug_assign(self):
     """Compile an augmented assignment (``x += y`` and similar) into LLL.

     Only ``+``, ``-``, ``*``, ``/`` and ``%`` are accepted, and only on
     targets of a ``BaseType``.  The target location is bound to an alias
     with ``with``; the current value is loaded, combined with the
     right-hand side through a synthetic ``BinOp``, converted to the
     target's type and stored back.

     Returns ``None`` when the target location is neither ``'storage'`` nor
     ``'memory'``.

     Raises:
         StructureException: the operator is not one of the five above.
         TypeMismatchException: the target is not of a ``BaseType``.
     """
     stmt = self.stmt
     target = self.get_target(stmt.target)
     sub = Expr.parse_value_expr(stmt.value, self.context)

     supported_ops = (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Mod)
     if not isinstance(stmt.op, supported_ops):
         raise StructureException("Unsupported operator for augassign", stmt)
     if not isinstance(target.typ, BaseType):
         raise TypeMismatchException(
             "Can only use aug-assign operators with simple types!",
             stmt.target)

     # Pick the load/store opcodes and the `with` alias for the location.
     if target.location == 'storage':
         load_op, store_op, alias = 'sload', 'sstore', '_stloc'
     elif target.location == 'memory':
         load_op, store_op, alias = 'mload', 'mstore', '_mloc'
     else:
         return None

     loaded = LLLnode.from_list([load_op, alias],
                                typ=target.typ,
                                pos=target.pos)
     o = Expr.parse_value_expr(
         ast.BinOp(
             left=loaded,
             right=sub,
             op=stmt.op,
             lineno=stmt.lineno,
             col_offset=stmt.col_offset,
             end_lineno=stmt.end_lineno,
             end_col_offset=stmt.end_col_offset,
         ),
         self.context,
     )
     stored = base_type_conversion(o, o.typ, target.typ, pos=getpos(stmt))
     store_node = [store_op, alias, stored]
     return LLLnode.from_list(['with', alias, target, store_node],
                              typ=None,
                              pos=getpos(stmt))
예제 #11
0
파일: stmt.py 프로젝트: zutobg/vyper
    def call(self):
        """Compile a call used as a statement into an LLL node.

        The callee shape is tested in this order: a bare function name, a
        ``self.<method>(...)`` internal call, ``<Name>(<address>).<method>(...)``,
        a member whose name is a key of ``context.sigs`` or of
        ``context.globals`` (external call on an address read from storage),
        and finally ``log.<Event>(...)``.

        Raises:
            StructureException: unknown or unusable function name, or an
                unsupported callee shape.
            ConstancyViolationException: a constant function calls a
                non-constant internal function.
            EventDeclarationException: undeclared event, or wrong number of
                event arguments.
        """
        from .parser import (
            pack_arguments,
            pack_logging_data,
            pack_logging_topics,
            external_contract_call,
        )
        # Bare name: only entries of stmt_dispatch_table may be used as statements.
        if isinstance(self.stmt.func, ast.Name):
            if self.stmt.func.id in stmt_dispatch_table:
                return stmt_dispatch_table[self.stmt.func.id](self.stmt, self.context)
            elif self.stmt.func.id in dispatch_table:
                raise StructureException("Function {} can not be called without being used.".format(self.stmt.func.id), self.stmt)
            else:
                raise StructureException("Unknown function: '{}'.".format(self.stmt.func.id), self.stmt)
        # self.<method>(...): emitted as a `call` to this contract's own address.
        elif isinstance(self.stmt.func, ast.Attribute) and isinstance(self.stmt.func.value, ast.Name) and self.stmt.func.value.id == "self":
            method_name = self.stmt.func.attr
            expr_args = [Expr(arg, self.context).lll_node for arg in self.stmt.args]
            # full_sig = FunctionSignature.get_full_sig(method_name, expr_args, self.context.sigs, self.context.custom_units)
            sig = FunctionSignature.lookup_sig(self.context.sigs, method_name, expr_args, self.stmt, self.context)
            if self.context.is_constant and not sig.const:
                raise ConstancyViolationException(
                    "May not call non-constant function '%s' within a constant function." % (sig.sig)
                )
            # The callee's recorded gas figure is added to this node's estimate.
            add_gas = self.context.sigs['self'][sig.sig].gas
            inargs, inargsize = pack_arguments(sig,
                                                expr_args,
                                                self.context, pos=getpos(self.stmt))
            return LLLnode.from_list(['assert', ['call', ['gas'], ['address'], 0, inargs, inargsize, 0, 0]],
                                        typ=None, pos=getpos(self.stmt), add_gas_estimate=add_gas, annotation='Internal Call: %s' % sig.sig)
        # <ContractName>(<address>).<method>(...): address is the parsed first argument.
        elif isinstance(self.stmt.func, ast.Attribute) and isinstance(self.stmt.func.value, ast.Call):
            contract_name = self.stmt.func.value.func.id
            contract_address = Expr.parse_value_expr(self.stmt.func.value.args[0], self.context)
            return external_contract_call(self.stmt, self.context, contract_name, contract_address, pos=getpos(self.stmt))
        # Member named like a key of context.sigs: the member name is the contract name.
        elif isinstance(self.stmt.func.value, ast.Attribute) and self.stmt.func.value.attr in self.context.sigs:
            contract_name = self.stmt.func.value.attr
            var = self.context.globals[self.stmt.func.value.attr]
            contract_address = unwrap_location(LLLnode.from_list(var.pos, typ=var.typ, location='storage', pos=getpos(self.stmt), annotation='self.' + self.stmt.func.value.attr))
            return external_contract_call(self.stmt, self.context, contract_name, contract_address, pos=getpos(self.stmt))
        # Member that is a global variable: the contract name comes from its type unit.
        elif isinstance(self.stmt.func.value, ast.Attribute) and self.stmt.func.value.attr in self.context.globals:
            contract_name = self.context.globals[self.stmt.func.value.attr].typ.unit
            var = self.context.globals[self.stmt.func.value.attr]
            contract_address = unwrap_location(LLLnode.from_list(var.pos, typ=var.typ, location='storage', pos=getpos(self.stmt), annotation='self.' + self.stmt.func.value.attr))
            return external_contract_call(self.stmt, self.context, contract_name, contract_address, pos=getpos(self.stmt))
        # log.<Event>(...): event emission.
        elif isinstance(self.stmt.func, ast.Attribute) and self.stmt.func.value.id == 'log':
            if self.stmt.func.attr not in self.context.sigs['self']:
                raise EventDeclarationException("Event not declared yet: %s" % self.stmt.func.attr)
            event = self.context.sigs['self'][self.stmt.func.attr]
            if len(event.indexed_list) != len(self.stmt.args):
                raise EventDeclarationException("%s received %s arguments but expected %s" % (event.name, len(self.stmt.args), len(event.indexed_list)))
            # Split declared and supplied arguments into indexed (topics) and
            # non-indexed (data) groups, keeping their relative order.
            expected_topics, topics = [], []
            expected_data, data = [], []
            for pos, is_indexed in enumerate(event.indexed_list):
                if is_indexed:
                    expected_topics.append(event.args[pos])
                    topics.append(self.stmt.args[pos])
                else:
                    expected_data.append(event.args[pos])
                    data.append(self.stmt.args[pos])
            topics = pack_logging_topics(event.event_id, topics, expected_topics, self.context, pos=getpos(self.stmt))
            inargs, inargsize, inargsize_node, inarg_start = pack_logging_data(expected_data, data, self.context, pos=getpos(self.stmt))

            # Use the static size unless a size node was returned, in which
            # case the size is loaded from that node at runtime.
            if inargsize_node is None:
                sz = inargsize
            else:
                sz = ['mload', inargsize_node]

            # The opcode name carries the topic count ("log" + number of topics).
            return LLLnode.from_list(['seq', inargs,
                LLLnode.from_list(["log" + str(len(topics)), inarg_start, sz] + topics, add_gas_estimate=inargsize * 10)], typ=None, pos=getpos(self.stmt))
        else:
            raise StructureException("Unsupported operator: %r" % ast.dump(self.stmt), self.stmt)
예제 #12
0
파일: stmt.py 프로젝트: 1holeshot/vyper
    def parse_for(self):
        """Compile a ``for`` loop into an LLL ``repeat`` node.

        List iteration is delegated to ``self.parse_for_list()``.  Otherwise
        the iterable must be a call to ``range`` with one or two arguments,
        in one of three forms: ``range(N)``, ``range(A, B)`` with both
        bounds constant, or ``range(x, x + N)``.

        The loop variable is registered on the context as an ``int128`` for
        the duration of the body and removed from ``context.vars`` and
        ``context.forvars`` afterwards.

        Raises:
            StructureException: unsupported iterable, wrong ``range`` arity,
                malformed two-argument form, or an iteration count below 1.
        """
        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self.parse_for_list()

        # Anything that is not a call gets an error message specific to its shape.
        if not isinstance(self.stmt.iter, vy_ast.Call):
            if isinstance(self.stmt.iter, vy_ast.Subscript):
                raise StructureException("Cannot iterate over a nested list",
                                         self.stmt.iter)
            raise StructureException(
                f"Cannot iterate over '{type(self.stmt.iter).__name__}' object",
                self.stmt.iter)
        # getattr with a default: the called object may have no `id` attribute.
        if getattr(self.stmt.iter.func, 'id', None) != "range":
            raise StructureException(
                "Non-literals cannot be used as loop range",
                self.stmt.iter.func)
        if len(self.stmt.iter.args) not in {1, 2}:
            raise StructureException(
                f"Range expects between 1 and 2 arguments, got {len(self.stmt.iter.args)}",
                self.stmt.iter.func)

        block_scope_id = id(self.stmt)
        with self.context.make_blockscope(block_scope_id):
            # Get arg0
            arg0 = self.stmt.iter.args[0]
            num_of_args = len(self.stmt.iter.args)

            # Type 1 for, e.g. for i in range(10): ...
            if num_of_args == 1:
                arg0_val = self._get_range_const_value(arg0)
                start = LLLnode.from_list(0,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = arg0_val

            # Type 2 for, e.g. for i in range(100, 110): ...
            elif self._check_valid_range_constant(self.stmt.iter.args[1],
                                                  raise_exception=False)[0]:
                arg0_val = self._get_range_const_value(arg0)
                arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
                start = LLLnode.from_list(arg0_val,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = LLLnode.from_list(arg1_val - arg0_val,
                                           typ='int128',
                                           pos=getpos(self.stmt))

            # Type 3 for, e.g. for i in range(x, x + 10): ...
            else:
                arg1 = self.stmt.iter.args[1]
                if not isinstance(arg1, vy_ast.BinOp) or not isinstance(
                        arg1.op, vy_ast.Add):
                    raise StructureException(
                        ("Two-arg for statements must be of the form `for i "
                         "in range(start, start + rounds): ...`"),
                        arg1,
                    )

                # The start expression must reappear as the left operand of the sum.
                if arg0 != arg1.left:
                    raise StructureException(
                        ("Two-arg for statements of the form `for i in "
                         "range(x, x + y): ...` must have x identical in both "
                         f"places: {vy_ast.ast_to_dict(arg0)} {vy_ast.ast_to_dict(arg1.left)}"
                         ),
                        self.stmt.iter,
                    )

                rounds = self._get_range_const_value(arg1.right)
                start = Expr.parse_value_expr(arg0, self.context)

            # `rounds` is either a plain int or a node exposing `.value`.
            r = rounds if isinstance(rounds, int) else rounds.value
            if r < 1:
                raise StructureException(
                    f"For loop has invalid number of iterations ({r}),"
                    " the value must be greater than zero", self.stmt.iter)

            # Register the loop variable so the body can refer to it.
            varname = self.stmt.target.id
            pos = self.context.new_variable(varname,
                                            BaseType('int128'),
                                            pos=getpos(self.stmt))
            self.context.forvars[varname] = True
            o = LLLnode.from_list(
                [
                    'repeat', pos, start, rounds,
                    parse_body(self.stmt.body, self.context)
                ],
                typ=None,
                pos=getpos(self.stmt),
            )
            # Drop the loop variable again once the body has been compiled.
            del self.context.vars[varname]
            del self.context.forvars[varname]

        return o
예제 #13
0
def process_arg(index, arg, expected_arg_typelist, function_name, context):
    """
    Match a call argument against the accepted type specifiers.

    The specifiers are tried in order and the first one that ``arg``
    satisfies decides the return value: a plain Python value for the
    ``*_literal`` specifiers, the untouched AST node for ``"*"``, or an LLL
    node for everything else.

    :param index: position of the argument; only used in the error message
    :param arg: AST node of the argument being checked
    :param expected_arg_typelist: a single specifier, a tuple of specifiers,
        or an ``Optional`` wrapper whose ``typ`` holds either of those
    :param function_name: name of the called function; only used in the
        error message
    :param context: context used to parse expressions and types
    :raises InvalidLiteral: a ``str_literal`` argument contains a character
        whose code point is 256 or above
    :raises TypeMismatch: none of the specifiers accepted the argument
    """
    if isinstance(expected_arg_typelist, Optional):
        expected_arg_typelist = expected_arg_typelist.typ
    if not isinstance(expected_arg_typelist, tuple):
        expected_arg_typelist = (expected_arg_typelist, )

    # Parsed form of `arg`, shared between iterations of the fallback case.
    vsub = None
    for expected_arg in expected_arg_typelist:

        # temporary hack, once we refactor this package none of this will exist
        if hasattr(expected_arg, "_id"):
            expected_arg = expected_arg._id

        if expected_arg == "num_literal":
            if isinstance(arg, (vy_ast.Int, vy_ast.Decimal)):
                return arg.n
            continue

        if expected_arg == "str_literal":
            if not isinstance(arg, vy_ast.Str):
                continue
            # Each character must fit in a single byte.
            code_points = []
            for char in arg.s:
                if ord(char) >= 256:
                    raise InvalidLiteral(
                        f"Cannot insert special character {char} into byte array",
                        arg,
                    )
                code_points.append(ord(char))
            return bytes(code_points)

        if expected_arg == "bytes_literal":
            if isinstance(arg, vy_ast.Bytes):
                return arg.s
            continue

        if expected_arg == "name_literal":
            if isinstance(arg, vy_ast.Name):
                return arg.id
            if isinstance(arg, vy_ast.Subscript) and arg.value.id == "bytes":
                return f"bytes[{arg.slice.value.n}]"
            continue

        if expected_arg == "*":
            # Wildcard: hand the raw AST node back to the caller.
            return arg

        if expected_arg == "bytes":
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, ByteArrayType):
                return sub
            continue

        if expected_arg == "string":
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, StringType):
                return sub
            continue

        # Anything else is treated as the source text of a type.
        # Does not work for unit-endowed types inside compound types, e.g. timestamp[2]
        parsed_expected_type = context.parse_type(
            vy_ast.parse_to_ast(expected_arg)[0].value,
            "memory",
        )

        if not isinstance(parsed_expected_type, BaseType):
            vsub = vsub or Expr(arg, context).lll_node
            if vsub.typ == parsed_expected_type:
                return Expr(arg, context).lll_node
            continue

        vsub = vsub or Expr.parse_value_expr(arg, context)

        # An integer literal is also accepted for int128/uint256 as long as
        # its value is within the bounds of the expected type.
        is_valid_integer = (
            expected_arg in ("int128", "uint256")
            and isinstance(vsub.typ, BaseType)
            and vsub.typ.typ in ("int128", "uint256")
            and vsub.typ.is_literal
            and SizeLimits.in_bounds(expected_arg, vsub.value)
        )

        if is_base_type(vsub.typ, expected_arg) or is_valid_integer:
            return vsub

    if len(expected_arg_typelist) == 1:
        raise TypeMismatch(
            f"Expecting {expected_arg} for argument {index} of {function_name}",
            arg)
    raise TypeMismatch(
        f"Expecting one of {expected_arg_typelist} for argument {index} of {function_name}",
        arg)
예제 #14
0
def process_arg(index, arg, expected_arg_typelist, function_name, context):
    """
    Match a call argument against the accepted type specifiers.

    Specifiers are tried in order; the first one satisfied by ``arg``
    decides what is returned: a plain Python value for the ``*_literal``
    specifiers (``num_literal`` also accepts a named ``uint256``/``int128``
    constant and returns its value), the untouched AST node for ``'*'``, or
    an LLL node for everything else.

    :param index: position of the argument; only used in the error message
    :param arg: AST node of the argument being checked
    :param expected_arg_typelist: a single specifier, a tuple of specifiers,
        or an ``Optional`` wrapper whose ``typ`` holds either of those
    :param function_name: name of the called function; only used in the
        error message
    :param context: context used to look up constants and to parse
        expressions and types
    :raises InvalidLiteral: a ``str_literal`` argument contains a character
        whose code point is 256 or above
    :raises TypeMismatch: none of the specifiers accepted the argument
    """
    if isinstance(expected_arg_typelist, Optional):
        expected_arg_typelist = expected_arg_typelist.typ
    if not isinstance(expected_arg_typelist, tuple):
        expected_arg_typelist = (expected_arg_typelist, )

    # Parsed form of `arg`, shared between iterations of the fallback case.
    vsub = None
    for expected_arg in expected_arg_typelist:
        if expected_arg == 'num_literal':
            constants = context.constants
            if constants.is_constant_of_base_type(arg, ('uint256', 'int128')):
                return constants.get_constant(arg.id, None).value
            if isinstance(arg, (vy_ast.Int, vy_ast.Decimal)):
                return arg.n
            continue

        if expected_arg == 'str_literal':
            if not isinstance(arg, vy_ast.Str):
                continue
            # Each character must fit in a single byte.
            code_points = []
            for char in arg.s:
                if ord(char) >= 256:
                    raise InvalidLiteral(
                        f"Cannot insert special character {char} into byte array",
                        arg,
                    )
                code_points.append(ord(char))
            return bytes(code_points)

        if expected_arg == 'bytes_literal':
            if isinstance(arg, vy_ast.Bytes):
                return arg.s
            continue

        if expected_arg == 'name_literal':
            if isinstance(arg, vy_ast.Name):
                return arg.id
            if isinstance(arg, vy_ast.Subscript) and arg.value.id == 'bytes':
                return f'bytes[{arg.slice.value.n}]'
            continue

        if expected_arg == '*':
            # Wildcard: hand the raw AST node back to the caller.
            return arg

        if expected_arg == 'bytes':
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, ByteArrayType):
                return sub
            continue

        if expected_arg == 'string':
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, StringType):
                return sub
            continue

        # Anything else is treated as the source text of a type.
        # Does not work for unit-endowed types inside compound types, e.g. timestamp[2]
        parsed_expected_type = context.parse_type(
            vy_ast.parse_to_ast(expected_arg)[0].value,
            'memory',
        )

        if not isinstance(parsed_expected_type, BaseType):
            vsub = vsub or Expr(arg, context).lll_node
            if vsub.typ == parsed_expected_type:
                return Expr(arg, context).lll_node
            continue

        vsub = vsub or Expr.parse_value_expr(arg, context)

        # An integer literal is also accepted for int128/uint256 as long as
        # its value is within the bounds of the expected type.
        is_valid_integer = (
            expected_arg in ('int128', 'uint256')
            and isinstance(vsub.typ, BaseType)
            and vsub.typ.typ in ('int128', 'uint256')
            and vsub.typ.is_literal
            and SizeLimits.in_bounds(expected_arg, vsub.value)
        )

        if is_base_type(vsub.typ, expected_arg) or is_valid_integer:
            return vsub

    if len(expected_arg_typelist) == 1:
        raise TypeMismatch(
            f"Expecting {expected_arg} for argument {index} of {function_name}",
            arg
        )
    raise TypeMismatch(
        f"Expecting one of {expected_arg_typelist} for argument {index} of {function_name}",
        arg
    )
예제 #15
0
    def parse_for(self):
        """
        Build the LLL ``repeat`` node for a ``for`` statement.

        List iteration is delegated to ``parse_for_list``. Otherwise the
        statement has to be one of:

        * ``for i in range(N)`` with a literal ``N``
        * ``for i in range(A, B)`` with literal ``A`` and ``B``
        * ``for i in range(x, x + N)`` with a literal ``N``

        :raises StructureException: the statement matches none of the forms
            above
        :return: LLL node ``['repeat', pos, start, rounds, body]``
        """
        from .parser import (
            parse_body, )
        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self.parse_for_list()

        # The order of the checks matters: later ones read attributes that
        # the earlier ones establish.
        is_range_loop = (
            isinstance(self.stmt.iter, ast.Call)
            and isinstance(self.stmt.iter.func, ast.Name)
            and isinstance(self.stmt.target, ast.Name)
            and self.stmt.iter.func.id == "range"
            and len(self.stmt.iter.args) in (1, 2)
        )
        if not is_range_loop:
            raise StructureException(
                "For statements must be of the form `for i in range(rounds): ..` or `for i in range(start, start + rounds): ..`",  # noqa
                self.stmt.iter)

        block_scope_id = id(self.stmt.orelse)
        self.context.start_blockscope(block_scope_id)

        range_args = self.stmt.iter.args
        if len(range_args) == 1:
            # Type 1 for, e.g. for i in range(10): ...
            count_node = range_args[0]
            if not isinstance(count_node, ast.Num):
                raise StructureException("Range only accepts literal values",
                                         self.stmt.iter)
            start = LLLnode.from_list(0, typ='int128', pos=getpos(self.stmt))
            rounds = count_node.n
        else:
            lower, upper = range_args
            if isinstance(lower, ast.Num) and isinstance(upper, ast.Num):
                # Type 2 for, e.g. for i in range(100, 110): ...
                start = LLLnode.from_list(lower.n,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = LLLnode.from_list(upper.n - lower.n,
                                           typ='int128',
                                           pos=getpos(self.stmt))
            else:
                # Type 3 for, e.g. for i in range(x, x + 10): ...
                is_addition = (isinstance(upper, ast.BinOp)
                               and isinstance(upper.op, ast.Add))
                if not is_addition:
                    raise StructureException(
                        "Two-arg for statements must be of the form `for i in range(start, start + rounds): ...`",
                        upper)
                # The start expression must be repeated verbatim on the left
                # of the addition; compare the dumped AST of both.
                lower_dump = ast.dump(lower)
                left_dump = ast.dump(upper.left)
                if lower_dump != left_dump:
                    raise StructureException(
                        "Two-arg for statements of the form `for i in range(x, x + y): ...` must have x identical in both places: %r %r"
                        % (lower_dump, left_dump), self.stmt.iter)
                if not isinstance(upper.right, ast.Num):
                    raise StructureException(
                        "Range only accepts literal values", upper)
                start = Expr.parse_value_expr(lower, self.context)
                rounds = upper.right.n

        # The loop variable only exists while the body is being parsed.
        varname = self.stmt.target.id
        pos = self.context.new_variable(varname, BaseType('int128'))
        self.context.forvars[varname] = True
        body = parse_body(self.stmt.body, self.context)
        o = LLLnode.from_list(['repeat', pos, start, rounds, body],
                              typ=None,
                              pos=getpos(self.stmt))
        del self.context.vars[varname]
        del self.context.forvars[varname]
        self.context.end_blockscope(block_scope_id)
        return o
예제 #16
0
파일: stmt.py 프로젝트: spencerbh/vyper
    def parse_for(self):
        """
        Build the LLL ``repeat`` node for a ``for`` statement.

        List iteration is delegated to ``parse_for_list``. Otherwise the
        statement has to be ``for i in range(rounds)``,
        ``for i in range(a, b)`` with constant bounds, or
        ``for i in range(x, x + rounds)``.

        :raises StructureException: the statement matches none of the
            accepted forms
        :return: LLL node ``['repeat', pos, start, rounds, body]``
        """
        from .parser import (
            parse_body, )
        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self.parse_for_list()

        # The order of the checks matters: later ones read attributes that
        # the earlier ones establish.
        is_range_loop = (
            isinstance(self.stmt.iter, ast.Call)
            and isinstance(self.stmt.iter.func, ast.Name)
            and isinstance(self.stmt.target, ast.Name)
            and self.stmt.iter.func.id == "range"
            and len(self.stmt.iter.args) in (1, 2)
        )
        if not is_range_loop:
            raise StructureException(
                "For statements must be of the form `for i in range(rounds): ..` or `for i in range(start, start + rounds): ..`",  # noqa
                self.stmt.iter)

        block_scope_id = id(self.stmt.orelse)
        self.context.start_blockscope(block_scope_id)

        range_args = self.stmt.iter.args
        arg0 = range_args[0]

        if len(range_args) == 1:
            # Type 1 for, e.g. for i in range(10): ...
            rounds = self._get_range_const_value(arg0)
            start = LLLnode.from_list(0, typ='int128', pos=getpos(self.stmt))

        # NOTE(review): the helper's return value is tested for truthiness
        # as-is; if it returns a tuple rather than a bool this branch would
        # always be taken - confirm against its definition.
        elif self._check_valid_range_constant(range_args[1],
                                              raise_exception=False):
            # Type 2 for, e.g. for i in range(100, 110): ...
            lower = self._get_range_const_value(arg0)
            upper = self._get_range_const_value(range_args[1])
            start = LLLnode.from_list(lower,
                                      typ='int128',
                                      pos=getpos(self.stmt))
            rounds = LLLnode.from_list(upper - lower,
                                       typ='int128',
                                       pos=getpos(self.stmt))

        else:
            # Type 3 for, e.g. for i in range(x, x + 10): ...
            arg1 = range_args[1]
            is_addition = (isinstance(arg1, ast.BinOp)
                           and isinstance(arg1.op, ast.Add))
            if not is_addition:
                raise StructureException(
                    "Two-arg for statements must be of the form `for i in range(start, start + rounds): ...`",
                    arg1)

            # The start expression must be repeated verbatim on the left of
            # the addition; compare the dumped AST of both.
            start_dump = ast.dump(arg0)
            left_dump = ast.dump(arg1.left)
            if start_dump != left_dump:
                raise StructureException((
                    "Two-arg for statements of the form `for i in range(x, x + y): ...`"
                    " must have x identical in both places: %r %r") %
                                         (start_dump, left_dump),
                                         self.stmt.iter)

            rounds = self._get_range_const_value(arg1.right)
            start = Expr.parse_value_expr(arg0, self.context)

        # The loop variable only exists while the body is being parsed.
        varname = self.stmt.target.id
        pos = self.context.new_variable(varname, BaseType('int128'))
        self.context.forvars[varname] = True
        body = parse_body(self.stmt.body, self.context)
        o = LLLnode.from_list(['repeat', pos, start, rounds, body],
                              typ=None,
                              pos=getpos(self.stmt))
        del self.context.vars[varname]
        del self.context.forvars[varname]
        self.context.end_blockscope(block_scope_id)
        return o
예제 #17
0
파일: stmt.py 프로젝트: agroce/vyper
    def parse_for(self):
        """
        Build the LLL ``repeat`` node for a ``for`` statement.

        List iteration is delegated to ``parse_for_list``. Otherwise the
        statement has to be ``for i in range(rounds)``,
        ``for i in range(a, b)`` with constant bounds, or
        ``for i in range(x, x + rounds)``.

        :raises StructureException: the statement matches none of the
            accepted forms
        :return: LLL node ``['repeat', pos, start, rounds, body]``
        """
        # NOTE: `parse_body` is resolved from the enclosing module scope.

        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self.parse_for_list()

        # These checks have to short-circuit: the later ones read
        # `iter.func` / `iter.args`, which are only safe to touch once the
        # earlier checks have passed. Collecting them eagerly (e.g. in a
        # tuple handed to `any`) evaluates every element up front and fails
        # with AttributeError instead of the StructureException below.
        is_valid_for_statement = (
            isinstance(self.stmt.iter, ast.Call)
            and isinstance(self.stmt.iter.func, ast.Name)
            and isinstance(self.stmt.target, ast.Name)
            and self.stmt.iter.func.id == "range"
            and len(self.stmt.iter.args) in {1, 2}
        )
        if not is_valid_for_statement:
            raise StructureException(
                ("For statements must be of the form `for i in range(rounds): "
                 "..` or `for i in range(start, start + rounds): ..`"),
                self.stmt.iter)

        block_scope_id = id(self.stmt)
        with self.context.make_blockscope(block_scope_id):
            # Get arg0
            arg0 = self.stmt.iter.args[0]
            num_of_args = len(self.stmt.iter.args)

            # Type 1 for, e.g. for i in range(10): ...
            if num_of_args == 1:
                arg0_val = self._get_range_const_value(arg0)
                start = LLLnode.from_list(0,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = arg0_val

            # Type 2 for, e.g. for i in range(100, 110): ...
            elif self._check_valid_range_constant(self.stmt.iter.args[1],
                                                  raise_exception=False)[0]:
                arg0_val = self._get_range_const_value(arg0)
                arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
                start = LLLnode.from_list(arg0_val,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = LLLnode.from_list(arg1_val - arg0_val,
                                           typ='int128',
                                           pos=getpos(self.stmt))

            # Type 3 for, e.g. for i in range(x, x + 10): ...
            else:
                arg1 = self.stmt.iter.args[1]
                if not isinstance(arg1, ast.BinOp) or not isinstance(
                        arg1.op, ast.Add):
                    raise StructureException(
                        ("Two-arg for statements must be of the form `for i "
                         "in range(start, start + rounds): ...`"),
                        arg1,
                    )

                # The start expression must be repeated on the left of the
                # addition.
                if arg0 != arg1.left:
                    raise StructureException(
                        ("Two-arg for statements of the form `for i in "
                         "range(x, x + y): ...` must have x identical in both "
                         f"places: {ast_to_dict(arg0)} {ast_to_dict(arg1.left)}"
                         ),
                        self.stmt.iter,
                    )

                rounds = self._get_range_const_value(arg1.right)
                start = Expr.parse_value_expr(arg0, self.context)

            # The loop variable only exists while the body is being parsed.
            varname = self.stmt.target.id
            pos = self.context.new_variable(varname,
                                            BaseType('int128'),
                                            pos=getpos(self.stmt))
            self.context.forvars[varname] = True
            o = LLLnode.from_list(
                [
                    'repeat', pos, start, rounds,
                    parse_body(self.stmt.body, self.context)
                ],
                typ=None,
                pos=getpos(self.stmt),
            )
            del self.context.vars[varname]
            del self.context.forvars[varname]

        return o
예제 #18
0
파일: stmt.py 프로젝트: zutobg/vyper
 def parse_assert(self):
     """Return an LLL ``assert`` node wrapping the statement's test expression."""
     condition = Expr.parse_value_expr(self.stmt.test, self.context)
     location = getpos(self.stmt)
     return LLLnode.from_list(['assert', condition], typ=None, pos=location)
예제 #19
0
def make_external_call(stmt_expr, context):
    """
    Build the external call for a call expression.

    Three callee shapes are recognised:

    * ``<Call>.method(...)``: the contract name is the name of the inner
      call's function and the address is its first argument;
    * ``<...>.<attr>.method(...)`` where ``attr`` is a key of
      ``context.sigs``: the contract name is ``attr`` and the address is
      read from the storage variable ``attr``;
    * ``<...>.<attr>.method(...)`` where ``attr`` is a global whose type
      has a ``name``: the contract name is that ``name`` and the address is
      read from the storage variable ``attr``.

    :param stmt_expr: AST node of the call expression
    :param context: context providing ``sigs`` and ``globals``
    :raises StructureException: the callee matches none of the shapes above
    :return: whatever ``external_call`` returns
    """
    from vyper.parser.expr import Expr

    value, gas = get_external_interface_keywords(stmt_expr, context)
    callee = stmt_expr.func

    if isinstance(callee, vy_ast.Attribute) and isinstance(
            callee.value, vy_ast.Call):
        inner_call = callee.value
        contract_name = inner_call.func.id
        contract_address = Expr.parse_value_expr(inner_call.args[0], context)
    else:
        owner = callee.value
        if not isinstance(owner, vy_ast.Attribute):
            raise StructureException("Unsupported operator.", stmt_expr)

        member = owner.attr
        if member in context.sigs:  # noqa: E501
            contract_name = member
        elif (member in context.globals
              and hasattr(context.globals[member].typ, "name")):
            contract_name = context.globals[member].typ.name
        else:
            raise StructureException("Unsupported operator.", stmt_expr)

        # Both attribute-based shapes load the address from storage.
        var = context.globals[member]
        contract_address = unwrap_location(
            LLLnode.from_list(
                var.pos,
                typ=var.typ,
                location="storage",
                pos=getpos(stmt_expr),
                annotation="self." + member,
            ))

    return external_call(
        stmt_expr,
        context,
        contract_name,
        contract_address,
        pos=getpos(stmt_expr),
        value=value,
        gas=gas,
    )
예제 #20
0
파일: stmt.py 프로젝트: erdnaag/vyper
    def parse_For(self):
        """
        Build the LLL ``repeat`` node for a ``for`` statement.

        List iteration is delegated to ``_parse_for_list``. For a ``range``
        loop the start and the number of rounds are derived from the one or
        two call arguments.

        :return: the ``repeat`` LLL node, or ``None`` when the iterator is
            not a ``range`` call with one or two arguments or when the
            number of rounds is below one
        """
        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self._parse_for_list()

        # Anything that is not `range` with one or two arguments is not
        # handled here.
        if (self.stmt.get("iter.func.id") != "range"
                or len(self.stmt.iter.args) not in (1, 2)):
            return

        block_scope_id = id(self.stmt)
        with self.context.make_blockscope(block_scope_id):
            range_args = self.stmt.iter.args
            arg0 = range_args[0]

            if len(range_args) == 1:
                # Type 1 for, e.g. for i in range(10): ...
                rounds = self._get_range_const_value(arg0)
                start = LLLnode.from_list(0,
                                          typ="int128",
                                          pos=getpos(self.stmt))

            elif self._check_valid_range_constant(range_args[1],
                                                  raise_exception=False)[0]:
                # Type 2 for, e.g. for i in range(100, 110): ...
                lower = self._get_range_const_value(arg0)
                upper = self._get_range_const_value(range_args[1])
                start = LLLnode.from_list(lower,
                                          typ="int128",
                                          pos=getpos(self.stmt))
                rounds = LLLnode.from_list(upper - lower,
                                           typ="int128",
                                           pos=getpos(self.stmt))

            else:
                # Type 3 for, e.g. for i in range(x, x + 10): ...
                arg1 = range_args[1]
                rounds = self._get_range_const_value(arg1.right)
                start = Expr.parse_value_expr(arg0, self.context)

            # `rounds` is either a plain int or a node carrying the count in
            # its `value` attribute.
            if isinstance(rounds, int):
                r = rounds
            else:
                r = rounds.value
            if r < 1:
                return

            # The loop variable only exists while the body is being parsed.
            varname = self.stmt.target.id
            pos = self.context.new_variable(varname,
                                            BaseType("int128"),
                                            pos=getpos(self.stmt))
            self.context.forvars[varname] = True
            body = parse_body(self.stmt.body, self.context)
            lll_node = LLLnode.from_list(
                ["repeat", pos, start, rounds, body],
                typ=None,
                pos=getpos(self.stmt),
            )
            del self.context.vars[varname]
            del self.context.forvars[varname]

        return lll_node