Ejemplo n.º 1
0
    def get_global_context(cls, code):
        """Build a global context from a module's top-level statements.

        Each statement in ``code`` is classified and recorded on a fresh
        ``cls()`` instance:

        * ``ast.ClassDef`` -- an external contract declaration; its body is
          stored in ``_contracts`` via ``add_contract``;
        * ``ast.AnnAssign`` (``name: type``) -- passed to ``add_globals``;
        * ``ast.FunctionDef`` -- appended to ``_defs``.

        Generated getters (``_getters``) are appended to ``_defs`` at the end.

        Raises:
            StructureException: a contract declaration appears after a global
                or a function, or a statement of any other kind is found.
            FunctionDeclarationException: a function name shadows a global
                declared before it.
        """
        global_ctx = cls()

        for item in code:
            # Contract references
            if isinstance(item, ast.ClassDef):
                # TODO: once events are supported, contracts must also precede events.
                if global_ctx._globals or global_ctx._defs:
                    raise StructureException("External contract declarations must come before global declarations and function definitions", item)
                global_ctx._contracts[item.name] = global_ctx.add_contract(item.body)
            # Statements of the form:
            # variable_name: type
            elif isinstance(item, ast.AnnAssign):
                # TODO: events are not handled here (cf. add_globals_and_events).
                global_ctx.add_globals(item)
            # Function definitions
            elif isinstance(item, ast.FunctionDef):
                if item.name in global_ctx._globals:
                    # Pass the node so the error carries a source position.
                    raise FunctionDeclarationException("Function name shadowing a variable name: %s" % item.name, item)
                global_ctx._defs.append(item)
            else:
                raise StructureException("Invalid top-level statement", item)
        # Add getters to _defs
        global_ctx._defs += global_ctx._getters
        return global_ctx
Ejemplo n.º 2
0
    def get_target(self, target):
        """Resolve an assignment target into a writable variable location.

        Rejects writes that would alter a list currently being iterated by an
        enclosing for-loop, reassign an in-use loop iterator, modify storage in
        a constant function, or modify an immutable location. Tuple targets are
        compiled directly via ``Expr``.

        Raises:
            StructureException: iterated list or loop iterator is being altered.
            ConstancyViolationException: storage write in a constant function,
                or write to an immutable location.
        """
        # Guard against mutating a list that an enclosing for-loop iterates.
        if isinstance(target, ast.Subscript) and self.context.in_for_loop:
            base = target.value
            iterated_name = None
            if isinstance(base, ast.Attribute):
                qualified = "%s.%s" % (base.value.id, base.attr)
                if qualified in self.context.in_for_loop:
                    iterated_name = qualified
            if isinstance(base, ast.Name) and base.id in self.context.in_for_loop:
                iterated_name = base.id
            if iterated_name is not None:
                raise StructureException("Altering list '%s' which is being iterated!" % iterated_name, self.stmt)

        # Loop iterators are read-only while their loop runs.
        if isinstance(target, ast.Name) and target.id in self.context.forvars:
            raise StructureException("Altering iterator '%s' which is in use!" % target.id, self.stmt)

        if isinstance(target, ast.Tuple):
            return Expr(target, self.context).lll_node

        location = Expr.parse_variable_location(target, self.context)
        writes_storage = location.location == 'storage'
        if writes_storage and self.context.is_constant:
            raise ConstancyViolationException("Cannot modify storage inside a constant function: %s" % location.annotation)
        if not location.mutable:
            raise ConstancyViolationException("Cannot modify function argument: %s" % location.annotation)
        return location
Ejemplo n.º 3
0
    def call(self):
        """Compile a call used as a standalone statement.

        Supported shapes of ``self.stmt.func``:

        * bare name -> statement-level builtin from ``stmt_dispatch_table``;
          expression-only builtins (``dispatch_table``) and unknown names are
          rejected;
        * ``self.method(...)`` -> internal call via ``self_call.make_call``;
        * ``Contract(addr).method(...)`` -> external call to ``addr``;
        * ``self.<var>.method(...)`` -> external call to the address held in
          storage variable ``<var>``;
        * ``log.Event(...)`` -> a LOGn node for an event found in
          ``self.context.sigs['self']``.

        Raises:
            StructureException: unknown function or unsupported call shape.
            EventDeclarationException: undeclared event or wrong argument count.
        """
        from .parser import (
            pack_logging_data,
            pack_logging_topics,
            external_contract_call,
        )
        func = self.stmt.func

        if isinstance(func, ast.Name):
            if func.id in stmt_dispatch_table:
                return stmt_dispatch_table[func.id](self.stmt, self.context)
            elif func.id in dispatch_table:
                raise StructureException("Function {} can not be called without being used.".format(func.id), self.stmt)
            else:
                raise StructureException("Unknown function: '{}'.".format(func.id), self.stmt)

        # Every remaining supported form is an attribute call `<receiver>.<name>(...)`.
        # Other callee shapes (e.g. `f()()`) used to crash with AttributeError on
        # `func.value`; reject them with a proper compiler error instead.
        if not isinstance(func, ast.Attribute):
            raise StructureException("Unsupported operator: %r" % ast.dump(self.stmt), self.stmt)

        receiver = func.value

        def storage_address(var_name):
            # Load the contract address stored in storage variable `var_name`.
            var = self.context.globals[var_name]
            return unwrap_location(LLLnode.from_list(var.pos, typ=var.typ, location='storage', pos=getpos(self.stmt), annotation='self.' + var_name))

        if isinstance(receiver, ast.Name) and receiver.id == "self":
            return self_call.make_call(self.stmt, self.context)
        elif isinstance(receiver, ast.Call):
            contract_name = receiver.func.id
            contract_address = Expr.parse_value_expr(receiver.args[0], self.context)
            return external_contract_call(self.stmt, self.context, contract_name, contract_address, pos=getpos(self.stmt))
        elif isinstance(receiver, ast.Attribute) and receiver.attr in self.context.sigs:
            contract_name = receiver.attr
            contract_address = storage_address(receiver.attr)
            return external_contract_call(self.stmt, self.context, contract_name, contract_address, pos=getpos(self.stmt))
        elif isinstance(receiver, ast.Attribute) and receiver.attr in self.context.globals:
            contract_name = self.context.globals[receiver.attr].typ.unit
            contract_address = storage_address(receiver.attr)
            return external_contract_call(self.stmt, self.context, contract_name, contract_address, pos=getpos(self.stmt))
        elif isinstance(receiver, ast.Name) and receiver.id == 'log':
            event_name = func.attr
            if event_name not in self.context.sigs['self']:
                raise EventDeclarationException("Event not declared yet: %s" % event_name, self.stmt)
            event = self.context.sigs['self'][event_name]
            if len(event.indexed_list) != len(self.stmt.args):
                raise EventDeclarationException("%s received %s arguments but expected %s" % (event.name, len(self.stmt.args), len(event.indexed_list)), self.stmt)
            # Split the call arguments into indexed topics and non-indexed data,
            # pairing each with its declared type from the event signature.
            expected_topics, topics = [], []
            expected_data, data = [], []
            for i, is_indexed in enumerate(event.indexed_list):
                if is_indexed:
                    expected_topics.append(event.args[i])
                    topics.append(self.stmt.args[i])
                else:
                    expected_data.append(event.args[i])
                    data.append(self.stmt.args[i])
            topics = pack_logging_topics(event.event_id, topics, expected_topics, self.context, pos=getpos(self.stmt))
            inargs, inargsize, inargsize_node, inarg_start = pack_logging_data(expected_data, data, self.context, pos=getpos(self.stmt))

            # When pack_logging_data returns a size node, the data length is
            # read from memory at runtime; otherwise the static size is used.
            sz = inargsize if inargsize_node is None else ['mload', inargsize_node]

            return LLLnode.from_list(['seq', inargs,
                LLLnode.from_list(["log" + str(len(topics)), inarg_start, sz] + topics, add_gas_estimate=inargsize * 10)], typ=None, pos=getpos(self.stmt))
        raise StructureException("Unsupported operator: %r" % ast.dump(self.stmt), self.stmt)
Ejemplo n.º 4
0
 def add_constant(self, item):
     """Register a module-level constant declaration.

     ``item`` is an ``ast.AnnAssign`` whose annotation is a call (e.g.
     ``constant(<type>)``) wrapping exactly one type expression. When the
     target name passes ``self.is_valid_varname``, the value produced by
     ``self.unroll_constant(item)`` is stored in ``self._constants``.

     Raises:
         StructureException: no value is assigned, the annotation does not
             wrap exactly one type expression, or the target is not a plain
             name.
     """
     args = item.annotation.args
     if not item.value:
         raise StructureException('Constants must express a value!', item)
     well_formed = (
         len(args) == 1
         and isinstance(args[0], (ast.Subscript, ast.Name, ast.Call))
         # Only a plain name target has `.id`; attribute or subscript targets
         # would otherwise crash with AttributeError below.
         and isinstance(item.target, ast.Name)
     )
     if not well_formed:
         raise StructureException('Incorrectly formatted constant', item)
     c_name = item.target.id
     if self.is_valid_varname(c_name, item):
         self._constants[c_name] = self.unroll_constant(item)
Ejemplo n.º 5
0
    def parse_delete(self):
        """Compile ``del <target>`` by overwriting the target with a null value.

        Only a single subscripted target whose compiled location is storage
        (e.g. ``del self.balances[addr]``) is supported; it is written through
        ``make_setter`` with a ``NullType`` value.

        Raises:
            StructureException: more than one target, or a target that is not
                a storage subscript.
        """
        from .parser import make_setter

        targets = self.stmt.targets
        if len(targets) != 1:
            raise StructureException("Can delete one variable at a time", self.stmt)
        target = targets[0]
        target_lll = Expr(target, self.context).lll_node

        deletable = isinstance(target, ast.Subscript) and target_lll.location == "storage"
        if not deletable:
            raise StructureException("Deleting type not supported.", self.stmt)

        null_value = LLLnode.from_list(None, typ=NullType())
        return make_setter(target_lll, null_value, "storage", pos=getpos(self.stmt))
Ejemplo n.º 6
0
 def parse_name(self):
     """Compile a statement that consists of a bare name.

     ``throw`` becomes an unconditional failing ``assert`` and ``vdb`` becomes
     a ``debugger`` node; every other bare name is rejected.

     Raises:
         StructureException: the name is not ``vdb`` or ``throw``.
     """
     name = self.stmt.id
     if name == "throw":
         return LLLnode.from_list(['assert', 0], typ=None, pos=getpos(self.stmt))
     if name == "vdb":
         return LLLnode('debugger', typ=None, pos=getpos(self.stmt))
     raise StructureException("Unsupported statement type: %s" % type(self.stmt), self.stmt)
Ejemplo n.º 7
0
 def __init__(self, stmt, context):
     """Compile ``stmt`` immediately and store the result in ``self.lll_node``.

     The AST class of the statement selects its compile method from
     ``self.stmt_table``.

     Raises:
         StructureException: the statement's AST class has no handler.
     """
     self.stmt = stmt
     self.context = context
     # Statement AST class -> bound compile method.
     self.stmt_table = {
         ast.Expr: self.expr,
         ast.Pass: self.parse_pass,
         ast.AnnAssign: self.ann_assign,
         ast.Assign: self.assign,
         ast.If: self.parse_if,
         ast.Call: self.call,
         ast.Assert: self.parse_assert,
         ast.For: self.parse_for,
         ast.AugAssign: self.aug_assign,
         ast.Break: self.parse_break,
         ast.Continue: self.parse_continue,
         ast.Return: self.parse_return,
         ast.Delete: self.parse_delete,
         ast.Str: self.parse_docblock,  # docblock
         ast.Name: self.parse_name,
     }
     handler = self.stmt_table.get(self.stmt.__class__)
     if handler is None:
         raise StructureException("Unsupported statement type: %s" % type(stmt), stmt)
     self.lll_node = handler()
Ejemplo n.º 8
0
 def tuple_literals(self):
     """Compile a tuple literal into a ``multi`` LLL node of ``TupleType``.

     Each element is compiled with ``Expr``; the compiled nodes are both the
     ``multi`` operands and the members passed to ``TupleType``.

     Raises:
         StructureException: the tuple is empty.
     """
     elements = self.expr.elts
     if len(elements) == 0:
         raise StructureException("Tuple must have elements", self.expr)
     members = [Expr(element, self.context).lll_node for element in elements]
     return LLLnode.from_list(["multi"] + members, typ=TupleType(members), pos=getpos(self.expr))
Ejemplo n.º 9
0
def parse_external_contracts(external_contracts, _contracts):
    """Build function signatures for every declared external contract.

    Args:
        external_contracts: dict updated in place; each contract name maps to
            ``{signature name: FunctionSignature}``.
        _contracts: mapping of contract name to the list of ``ast.FunctionDef``
            nodes declared in that contract interface.

    Returns:
        ``external_contracts`` (the same object, now populated).

    Raises:
        FunctionDeclarationException: a function name is declared more than
            once within one contract.
        StructureException: a function body is not exactly one ``constant``
            or ``modifying`` call-type marker.
    """
    for _contractname, _contract_defs in _contracts.items():
        _defnames = [_def.name for _def in _contract_defs]
        duplicates = [_def for _def in _contract_defs if _defnames.count(_def.name) > 1]
        if duplicates:
            # Report the first offending definition so the error has a source position.
            raise FunctionDeclarationException(
                "Duplicate function name: %s" % duplicates[0].name, duplicates[0])

        contract = {}
        for _def in _contract_defs:
            body = _def.body
            # The body must be a single bare `constant` or `modifying` name.
            has_call_type = (
                len(body) == 1
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Name)
                and body[0].value.id in ('modifying', 'constant')
            )
            if not has_call_type:
                raise StructureException(
                    'constant or modifying call type must be specified', _def)
            constant = body[0].value.id == 'constant'
            sig = FunctionSignature.from_definition(_def,
                                                    contract_def=True,
                                                    constant=constant)
            contract[sig.name] = sig
        external_contracts[_contractname] = contract
    return external_contracts
Ejemplo n.º 10
0
    def boolean_operations(self):
        """Compile a two-operand ``and`` / ``or`` expression over booleans.

        Raises:
            StructureException: not exactly two operands, or an operand is a
                call while an assignment is being compiled
                (``self.context.in_assignment``).
            TypeMismatchException: an operand is not ``bool``.
        """
        values = self.expr.values
        if len(values) != 2:
            raise StructureException("Expected two arguments for a bool op", self.expr)
        if self.context.in_assignment and any(isinstance(value, ast.Call) for value in values):
            raise StructureException("Boolean operations with calls may not be performed on assignment", self.expr)

        left = Expr.parse_value_expr(values[0], self.context)
        right = Expr.parse_value_expr(values[1], self.context)
        if not is_base_type(left.typ, 'bool') or not is_base_type(right.typ, 'bool'):
            raise TypeMismatchException("Boolean operations can only be between booleans!", self.expr)
        if isinstance(self.expr.op, ast.And):
            op = 'and'
        elif isinstance(self.expr.op, ast.Or):
            op = 'or'
        else:
            # Concatenating the AST node itself raised TypeError instead of
            # this message; format the operator's class name.
            raise Exception("Unsupported bool op: %s" % type(self.expr.op).__name__)
        return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))
Ejemplo n.º 11
0
    def parse_for(self):
        """Compile a ``for`` statement into a ``repeat`` LLL node.

        Iteration over a list is delegated to ``parse_for_list``. Otherwise the
        loop must be ``for <name> in range(...)`` in one of three forms:

        * Type 1: ``range(N)`` with a numeric literal ``N`` -- start 0, N rounds;
        * Type 2: ``range(A, B)`` with numeric literals -- start A, B - A rounds;
        * Type 3: ``range(x, x + N)`` where both ``x`` are AST-identical and
          ``N`` is a numeric literal -- start x (evaluated), N rounds.

        The loop variable is allocated as an ``int128`` variable and flagged in
        ``self.context.forvars`` while the body is compiled; both entries are
        removed again afterwards.

        Raises:
            StructureException: the loop header matches none of these forms.
        """
        from .parser import (
            parse_body,
        )
        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self.parse_for_list()

        # Require: call to the bare name `range`, a plain-name loop target, and 1 or 2 args.
        if not isinstance(self.stmt.iter, ast.Call) or \
            not isinstance(self.stmt.iter.func, ast.Name) or \
                not isinstance(self.stmt.target, ast.Name) or \
                    self.stmt.iter.func.id != "range" or \
                        len(self.stmt.iter.args) not in (1, 2):
            raise StructureException("For statements must be of the form `for i in range(rounds): ..` or `for i in range(start, start + rounds): ..`", self.stmt.iter)  # noqa

        # The block scope is keyed on id(self.stmt.orelse).
        # NOTE(review): the `else:` branch of the loop is never compiled here,
        # so a `for ... else` would silently drop it -- confirm it is rejected elsewhere.
        block_scope_id = id(self.stmt.orelse)
        self.context.start_blockscope(block_scope_id)
        # NOTE(review): rounds is never checked to be positive (e.g. range(0),
        # range(10, 5)) nor integral (ast.Num may hold a float) -- confirm.
        # Type 1 for, e.g. for i in range(10): ...
        if len(self.stmt.iter.args) == 1:
            if not isinstance(self.stmt.iter.args[0], ast.Num):
                raise StructureException("Range only accepts literal values", self.stmt.iter)
            start = LLLnode.from_list(0, typ='int128', pos=getpos(self.stmt))
            rounds = self.stmt.iter.args[0].n
        elif isinstance(self.stmt.iter.args[0], ast.Num) and isinstance(self.stmt.iter.args[1], ast.Num):
            # Type 2 for, e.g. for i in range(100, 110): ...
            start = LLLnode.from_list(self.stmt.iter.args[0].n, typ='int128', pos=getpos(self.stmt))
            rounds = LLLnode.from_list(self.stmt.iter.args[1].n - self.stmt.iter.args[0].n, typ='int128', pos=getpos(self.stmt))
        else:
            # Type 3 for, e.g. for i in range(x, x + 10): ...
            if not isinstance(self.stmt.iter.args[1], ast.BinOp) or not isinstance(self.stmt.iter.args[1].op, ast.Add):
                raise StructureException("Two-arg for statements must be of the form `for i in range(start, start + rounds): ...`",
                                            self.stmt.iter.args[1])
            # Both occurrences of `x` must be structurally identical ASTs, so the
            # round count is the literal N regardless of x's runtime value.
            if ast.dump(self.stmt.iter.args[0]) != ast.dump(self.stmt.iter.args[1].left):
                raise StructureException("Two-arg for statements of the form `for i in range(x, x + y): ...` must have x identical in both places: %r %r" % (ast.dump(self.stmt.iter.args[0]), ast.dump(self.stmt.iter.args[1].left)), self.stmt.iter)
            if not isinstance(self.stmt.iter.args[1].right, ast.Num):
                raise StructureException("Range only accepts literal values", self.stmt.iter.args[1])
            start = Expr.parse_value_expr(self.stmt.iter.args[0], self.context)
            rounds = self.stmt.iter.args[1].right.n
        # Allocate the loop variable and flag it in forvars before compiling the
        # body; the body is compiled with it in scope.
        varname = self.stmt.target.id
        pos = self.context.new_variable(varname, BaseType('int128'))
        self.context.forvars[varname] = True
        o = LLLnode.from_list(['repeat', pos, start, rounds, parse_body(self.stmt.body, self.context)], typ=None, pos=getpos(self.stmt))
        # Remove the loop variable again once the body has been compiled.
        del self.context.vars[varname]
        del self.context.forvars[varname]
        self.context.end_blockscope(block_scope_id)
        return o
Ejemplo n.º 12
0
 def add_contract(code):
     """Collect the function definitions of an external contract declaration.

     Args:
         code: body of the contract's ``class`` node (an iterable of AST nodes).

     Returns:
         A new list holding the ``ast.FunctionDef`` nodes in their original order.

     Raises:
         StructureException: on the first node that is not a function definition.
     """
     function_defs = []
     for node in code:
         if not isinstance(node, ast.FunctionDef):
             raise StructureException("Invalid contract reference", node)
         function_defs.append(node)
     return function_defs
Ejemplo n.º 13
0
    def call(self):
        """Compile a call expression whose result is used as a value.

        Supported shapes of ``self.expr.func``:

        * bare name -> builtin from ``dispatch_table``; otherwise an error that
          suggests ``self.<name>`` when a same-named function is in
          ``self.context.sigs['self']``;
        * ``self.method(...)`` -> internal call via ``self_call.make_call``;
        * ``Contract(addr).method(...)`` -> external call to ``addr``;
        * ``self.<var>.method(...)`` -> external call to the address stored in
          storage variable ``<var>``.

        External calls forward the ``value``/``gas`` pair returned by
        ``self._get_external_contract_keywords()``.

        Raises:
            StructureException: unknown function or unsupported call shape.
        """
        from ophydia.parser.parser import (
            external_contract_call
        )
        from ophydia.functions import (
            dispatch_table,
        )

        func = self.expr.func

        if isinstance(func, ast.Name):
            function_name = func.id
            if function_name in dispatch_table:
                return dispatch_table[function_name](self.expr, self.context)
            err_msg = "Not a top-level function: {}".format(function_name)
            # Signature keys may carry an argument-list suffix ("name(...)");
            # compare on the part before '('.
            own_function_names = {key.split('(')[0] for key in self.context.sigs['self']}
            if function_name in own_function_names:
                err_msg += ". Did you mean self.{}?".format(function_name)
            raise StructureException(err_msg, self.expr)

        # Every remaining supported form is an attribute call `<receiver>.<name>(...)`.
        # Other callee shapes (e.g. `f()()`) used to crash with AttributeError on
        # `func.value`; reject them with a proper compiler error instead.
        if not isinstance(func, ast.Attribute):
            raise StructureException("Unsupported operator: %r" % ast.dump(self.expr), self.expr)

        receiver = func.value

        def storage_address(var_name):
            # Load the contract address stored in storage variable `var_name`.
            var = self.context.globals[var_name]
            return unwrap_location(LLLnode.from_list(var.pos, typ=var.typ, location='storage', pos=getpos(self.expr), annotation='self.' + var_name))

        if isinstance(receiver, ast.Name) and receiver.id == "self":
            return self_call.make_call(self.expr, self.context)
        elif isinstance(receiver, ast.Call):
            contract_name = receiver.func.id
            contract_address = Expr.parse_value_expr(receiver.args[0], self.context)
        elif isinstance(receiver, ast.Attribute) and receiver.attr in self.context.sigs:
            contract_name = receiver.attr
            contract_address = storage_address(receiver.attr)
        elif isinstance(receiver, ast.Attribute) and receiver.attr in self.context.globals:
            contract_name = self.context.globals[receiver.attr].typ.unit
            contract_address = storage_address(receiver.attr)
        else:
            raise StructureException("Unsupported operator: %r" % ast.dump(self.expr), self.expr)
        value, gas = self._get_external_contract_keywords()
        return external_contract_call(self.expr, self.context, contract_name, contract_address, pos=getpos(self.expr), value=value, gas=gas)
Ejemplo n.º 14
0
def pre_parse(code):
    """Rewrite contract source into plain Python before AST parsing.

    Operates on the token stream of ``code``:

    * a comment containing ``@version`` is passed (minus its leading ``#``)
      to ``parse_version_pragma``;
    * the name ``contract`` at column 0 is renamed to ``class`` so contract
      declarations parse as class definitions;
    * any ``;`` operator token is rejected.

    Returns:
        The re-assembled source text.

    Raises:
        StructureException: a semicolon is found, or tokenizing fails.
    """
    rewritten = []
    readline = io.BytesIO(code.encode('utf-8')).readline

    try:
        for tok in tokenize(readline):
            if tok.type == COMMENT and "@version" in tok.string:
                parse_version_pragma(tok.string[1:])

            is_contract_keyword = (
                tok.type == NAME and tok.string == "contract" and tok.start[1] == 0
            )
            if is_contract_keyword:
                # Alias contract definition to class definition.
                tok = TokenInfo(tok.type, "class", tok.start, tok.end, tok.line)
            elif tok.type == OP and tok.string == ";":
                raise StructureException("Semi-colon statements not allowed.",
                                         tok.start)

            rewritten.append(tok)
    except TokenError as err:
        raise StructureException(err.args[0], err.args[1]) from err

    return untokenize(rewritten).decode('utf-8')
Ejemplo n.º 15
0
 def list_literals(self):
     """Compile a list literal into a ``multi`` LLL node of ``ListType``.

     Every element is compiled with ``Expr``. The list's element type is the
     type of the first element. Each element is compared with the first using
     the ``.subtype.typ`` of types that have a ``subtype``, otherwise the type
     itself; any mismatch is rejected.

     Raises:
         StructureException: the list is empty.
         TypeMismatchException: elements of differing types.
     """
     def type_key(typ):
         return typ.subtype.typ if hasattr(typ, 'subtype') else typ

     elements = self.expr.elts
     if not len(elements):
         raise StructureException("List must have elements", self.expr)
     compiled = []
     element_type = None
     for node in elements:
         lll = Expr(node, self.context).lll_node
         compiled.append(lll)
         if not element_type:
             element_type = lll.typ
         this_key = type_key(lll.typ)
         first_key = type_key(element_type)
         if len(compiled) > 1 and this_key != first_key:
             raise TypeMismatchException("Lists may only contain one type", self.expr)
     return LLLnode.from_list(["multi"] + compiled, typ=ListType(element_type, len(compiled)), pos=getpos(self.expr))
Ejemplo n.º 16
0
 def subscript(self):
     """Compile ``<value>[<index>]`` for mappings, lists and tuples.

     Mappings and lists accept an arbitrary index expression (but not a
     slice). Tuples require an integer literal index within bounds. The
     result is ``add_variable_offset`` applied to the base location and
     inherits the base's ``mutable`` flag.

     Raises:
         StructureException: a slice is used on a mapping or list.
         TypeMismatchException: an invalid tuple index, or a base type that
             cannot be subscripted.
     """
     sub = Expr.parse_variable_location(self.expr.value, self.context)
     slice_node = self.expr.slice
     if isinstance(sub.typ, (MappingType, ListType)):
         if 'value' not in vars(slice_node):
             raise StructureException("Array access must access a single element, not a slice", self.expr)
         index = Expr.parse_value_expr(slice_node.value, self.context)
     elif isinstance(sub.typ, TupleType):
         # A slice (`t[a:b]`) has no `.value` and would crash with an
         # AttributeError; non-integer literals (e.g. `t[1.0]`) would slip
         # through as offsets. Both are rejected as invalid indices.
         index_node = getattr(slice_node, 'value', None)
         is_valid_index = (
             isinstance(index_node, ast.Num)
             and isinstance(index_node.n, int)
             and 0 <= index_node.n < len(sub.typ.members)
         )
         if not is_valid_index:
             raise TypeMismatchException("Tuple index invalid", index_node if index_node is not None else self.expr)
         index = index_node.n
     else:
         raise TypeMismatchException("Bad subscript attempt", self.expr.value)
     o = add_variable_offset(sub, index, pos=getpos(self.expr))
     o.mutable = sub.mutable
     return o
Ejemplo n.º 17
0
 def g(element, context):
     """Validate and process the arguments of a builtin function call.

     ``argz`` / ``kwargz`` (from the enclosing scope) declare the expected
     positional and keyword arguments; entries that are ``Optional`` supply
     ``.default`` when the caller omits them. Each supplied argument goes
     through ``process_arg``, and the results are forwarded as
     ``f(element, subs, kwsubs, context)``.

     Raises:
         StructureException: too many or too few positional arguments, a
             missing required keyword, or an unexpected keyword.
     """
     function_name = element.func.id
     given_args = element.args
     if len(given_args) > len(argz):
         raise StructureException(
             "Expected %d arguments for %s, got %d" %
             (len(argz), function_name, len(given_args)), element)

     subs = []
     for position, expected in enumerate(argz):
         if position < len(given_args):
             # The argument index handed to process_arg is 1-based.
             subs.append(process_arg(position + 1, given_args[position], expected,
                                     function_name, context))
             continue
         if not isinstance(expected, Optional):
             raise StructureException(
                 "Not enough arguments for function: {}".format(function_name),
                 element)
         subs.append(expected.default)

     supplied_kw = {kw.arg: kw.value for kw in element.keywords}
     kwsubs = {}
     for name, expected in kwargz.items():
         if name in supplied_kw:
             kwsubs[name] = process_arg(name, supplied_kw[name], expected,
                                        function_name, context)
         elif isinstance(expected, Optional):
             kwsubs[name] = expected.default
         else:
             raise StructureException(
                 "Function %s requires argument %s" % (function_name, name),
                 element)

     for name in supplied_kw:
         if name not in kwargz:
             raise StructureException("Unexpected argument: %s" % name, element)
     return f(element, subs, kwsubs, context)
Ejemplo n.º 18
0
 def get_item_name_and_attributes(self, item, attributes):
     """Unwrap a declaration's annotation to find its innermost type name.

     Recurses through ``AnnAssign`` (its annotation), ``Subscript`` (its
     value) and single-argument ``Call`` wrappers such as ``public(...)``,
     recording each wrapper's function name as ``attributes[name] = True``.

     Args:
         item: AST node to inspect.
         attributes: dict updated in place with wrapper names.

     Returns:
         ``(name, attributes)`` where ``name`` is the innermost ``ast.Name``
         id, or ``None`` when the innermost node is not a name.

     Raises:
         StructureException: a wrapper call does not have exactly one argument.
     """
     if isinstance(item, ast.Name):
         return item.id, attributes
     elif isinstance(item, ast.AnnAssign):
         return self.get_item_name_and_attributes(item.annotation, attributes)
     elif isinstance(item, ast.Subscript):
         return self.get_item_name_and_attributes(item.value, attributes)
     elif isinstance(item, ast.Call):
         attributes[item.func.id] = True
         # Raise for multiple args; pass the node so the error has a position.
         if len(item.args) != 1:
             raise StructureException("%s expects one arg (the type)" % item.func.id, item)
         return self.get_item_name_and_attributes(item.args[0], attributes)
     return None, attributes
Ejemplo n.º 19
0
def method_id(expr, args, kwargs, context):
    """Compute the 4-byte method id (first 4 bytes of sha3) of a signature.

    Args:
        expr: the call AST node, used for source positions.
        args: ``[signature, output_type]``; ``signature`` is a bytes value
            (tested with ``b' ' in``) and ``output_type`` is compared against
            the strings ``'bytes32'`` and ``'bytes[4]'``.
        kwargs: unused.
        context: compilation context, used to allocate a memory placeholder.

    Returns:
        For ``'bytes32'``, the method id as an integer literal node. For
        ``'bytes[4]'``, a memory byte array of length 4 holding the method id.

    Raises:
        TypeMismatchException: the signature contains a space.
        StructureException: ``output_type`` is neither supported value.
    """
    signature = args[0]
    if b' ' in signature:
        raise TypeMismatchException(
            'Invalid function signature no spaces allowed.', expr)
    # Named `selector` so it does not shadow this function's own name.
    selector = fourbytes_to_int(sha3(signature)[:4])
    output_type = args[1]
    if output_type == 'bytes32':
        return LLLnode(selector, typ=BaseType('bytes32'), pos=getpos(expr))
    elif output_type == 'bytes[4]':
        placeholder = LLLnode.from_list(
            context.new_placeholder(ByteArrayType(4)))
        # The first mstore puts the selector's low 4 bytes at
        # placeholder+32..+36 (just after the length word); the second writes
        # the length (4) into the word at placeholder.
        return LLLnode.from_list([
            'seq', ['mstore', ['add', placeholder, 4], selector],
            ['mstore', placeholder, 4], placeholder
        ],
                                 typ=ByteArrayType(4),
                                 location='memory',
                                 pos=getpos(expr))
    else:
        raise StructureException(
            'Can only produce bytes32 or bytes[4] as outputs', expr)
Ejemplo n.º 20
0
def raw_log(expr, args, kwargs, context):
    """Emit a raw LOGn with up to four bytes32 topics and a byte-array payload.

    Args:
        expr: the call AST node, used for source positions.
        args: ``[topics, data]``; ``topics`` must be an ``ast.List`` of at most
            four expressions, and ``data`` is a compiled byte-array node
            (its ``.location`` and ``.typ`` are read).
        kwargs: unused.
        context: compilation context, used to allocate a memory placeholder.

    Raises:
        StructureException: ``topics`` is not a list of 0-4 elements.
        TypeMismatchException: a topic is not ``bytes32``.
    """
    topic_list = args[0]
    if not isinstance(topic_list, ast.List) or len(topic_list.elts) > 4:
        raise StructureException(
            "Expecting a list of 0-4 topics as first argument", topic_list)

    topics = []
    for node in topic_list.elts:
        topic = Expr.parse_value_expr(node, context)
        if not is_base_type(topic.typ, 'bytes32'):
            raise TypeMismatchException(
                "Expecting a bytes32 argument as topic", node)
        topics.append(topic)

    data = args[1]
    log_op = "log" + str(len(topics))

    if data.location == "memory":
        # Already in memory: log directly from the array (length word, then bytes).
        log_call = [log_op, ["add", "_arr", 32], ["mload", "_arr"]] + topics
        return LLLnode.from_list(["with", "_arr", data, log_call],
                                 typ=None,
                                 pos=getpos(expr))

    # Any other location: copy the byte array into a fresh memory buffer first.
    buffer = context.new_placeholder(data.typ)
    buffer_node = LLLnode.from_list(buffer, typ=data.typ, location='memory')
    source_node = LLLnode.from_list('_sub', typ=data.typ, location=data.location)
    copier = make_byte_array_copier(buffer_node, source_node, pos=getpos(expr))
    log_call = [log_op, ["add", buffer_node, 32], ["mload", buffer_node]] + topics
    return LLLnode.from_list(["with", "_sub", data, ["seq", copier, log_call]],
                             typ=None,
                             pos=getpos(expr))
Ejemplo n.º 21
0
    def unary_operations(self):
        """Compile a unary operator expression: ``not x`` or ``-x``.

        ``not`` requires a ``bool`` operand and compiles to ``iszero``.
        ``-`` requires a numeric operand; an integer literal is folded into a
        new negative literal, anything else compiles to ``sub(0, x)``.

        Raises:
            TypeMismatchException: operand type unsupported for the operator.
            StructureException: any other unary operator (e.g. ``~``, ``+``).
        """
        operand = Expr.parse_value_expr(self.expr.operand, self.context)
        op = self.expr.op
        if isinstance(op, ast.Not):
            if isinstance(operand.typ, BaseType) and operand.typ.typ == 'bool':
                return LLLnode.from_list(["iszero", operand], typ='bool', pos=getpos(self.expr))
            raise TypeMismatchException("Only bool is supported for not operation, %r supplied." % operand.typ, self.expr)

        if isinstance(op, ast.USub):
            if not is_numeric_type(operand.typ):
                # Report at the source expression; the compiled operand node
                # carries no AST position.
                raise TypeMismatchException("Unsupported type for negation: %r" % operand.typ, self.expr)

            if operand.typ.is_literal and 'int' in operand.typ.typ:
                # Fold `-<int literal>`: build a new numeric literal at the same
                # source position and compile that instead.
                num = ast.Num(0 - operand.value)
                num.source_code = self.expr.source_code
                num.lineno = self.expr.lineno
                num.col_offset = self.expr.col_offset
                return Expr.parse_value_expr(num, self.context)

            return LLLnode.from_list(["sub", 0, operand], typ=operand.typ, pos=getpos(self.expr))

        raise StructureException("Only the 'not' and '-' unary operators are supported", self.expr)
Ejemplo n.º 22
0
 def parse_assert(self):
     """Compile an ``assert`` statement.

     Without a message this is a plain ``assert`` node. With a message, the
     revert payload is ``Error(string)``-encoded: the 4-byte selector, a
     32-byte offset word (value 32), then the reason as a byte array.

     Raises:
         TypeMismatchException: the test is not a boolean expression.
         StructureException: the reason string is empty or whitespace-only.
     """
     test_expr = Expr.parse_value_expr(self.stmt.test, self.context)
     if not self.is_bool_expr(test_expr):
         raise TypeMismatchException('Only boolean expressions allowed', self.stmt.test)
     if not self.stmt.msg:
         return LLLnode.from_list(['assert', test_expr], typ=None, pos=getpos(self.stmt))

     reason = self.stmt.msg.s
     if len(reason.strip()) == 0:
         raise StructureException('Empty reason string not allowed.', self.stmt)
     sig_placeholder = self.context.new_placeholder(BaseType(32))
     arg_placeholder = self.context.new_placeholder(BaseType(32))
     # Size the payload from the literal exactly as it is encoded below (the
     # un-stripped string). Sizing from the stripped text under-counted the
     # data words when surrounding whitespace crossed a 32-byte boundary,
     # truncating the revert data.
     reason_str_type = ByteArrayType(len(reason))
     placeholder_bytes = Expr(self.stmt.msg, self.context).lll_node
     method_id = fourbytes_to_int(sha3(b"Error(string)")[:4])
     # The selector is the low 4 bytes of the word at sig_placeholder, so the
     # payload starts 28 bytes into it.
     # NOTE(review): assumes the string literal is allocated in memory right
     # after arg_placeholder so the three pieces are contiguous -- confirm.
     assert_reason = \
         ['seq',
             ['mstore', sig_placeholder, method_id],
             ['mstore', arg_placeholder, 32],
             placeholder_bytes,
             ['assert_reason', test_expr, int(sig_placeholder + 28), int(4 + 32 + get_size_of_type(reason_str_type) * 32)]]
     return LLLnode.from_list(assert_reason, typ=None, pos=getpos(self.stmt))
Ejemplo n.º 23
0
 def assign(self):
     """Compile a single-target assignment such as ``x = y`` or ``x[4] = y``.

     ``x = RLPList(...)`` declares and initialises a new memory variable. Any
     other plain-name target must already be declared in
     ``self.context.vars``. Tuple-to-tuple assignment is rejected. Remaining
     targets are resolved through ``self.get_target``. The context's
     in-assignment flag is set while the statement is compiled.

     Raises:
         StructureException: more than one target.
         VariableDeclarationException: undeclared target name, or tuple to
             tuple assignment.
     """
     # Assignment (e.g. x[4] = y)
     if len(self.stmt.targets) != 1:
         raise StructureException("Assignment statement must have one target", self.stmt)
     target_node = self.stmt.targets[0]
     self.context.set_in_assignment(True)
     try:
         sub = Expr(self.stmt.value, self.context).lll_node
         # Determine if it's an RLPList assignment (compare by value, not identity).
         is_rlp_list = (
             isinstance(self.stmt.value, ast.Call)
             and getattr(self.stmt.value.func, 'id', '') == 'RLPList'
         )
         if is_rlp_list:
             pos = self.context.new_variable(target_node.id, sub.typ)
             variable_loc = LLLnode.from_list(pos, typ=sub.typ, location='memory', pos=getpos(self.stmt), annotation=target_node.id)
             o = make_setter(variable_loc, sub, 'memory', pos=getpos(self.stmt))
         # All other assignments are forbidden.
         elif isinstance(target_node, ast.Name) and target_node.id not in self.context.vars:
             raise VariableDeclarationException("Variable type not defined", self.stmt)
         elif isinstance(target_node, ast.Tuple) and isinstance(self.stmt.value, ast.Tuple):
             raise VariableDeclarationException("Tuple to tuple assignment not supported", self.stmt)
         else:
             # Checks to see if assignment is valid
             target = self.get_target(target_node)
             o = make_setter(target, sub, target.location, pos=getpos(self.stmt))
         o.pos = getpos(self.stmt)
     finally:
         # Reset the flag even if compilation fails so the context is not left dirty.
         self.context.set_in_assignment(False)
     return o
Ejemplo n.º 24
0
    def add_globals(self, item):
        """Register a top-level ``name: type`` declaration on this context.

        The declaration is processed in this order:

        1. ``constant(...)`` annotations are delegated to ``add_constant``.
        2. Non-event annotations are unwrapped and their wrapper keywords
           checked against ``valid_global_keywords``.
        3. Assigned values and ``event`` declarations are rejected, as are
           non-name targets.
        4. ``units: {...}`` defines custom units (only once).
        5. ``public(<type>)`` declares a storage variable and generates getters.
        6. Anything else declares a plain storage variable.

        Raises:
            StructureException: invalid keyword, assigned value, non-name
                target, or a global declared after a function definition.
            EventDeclarationException: an ``event`` declaration (unsupported).
            VariableDeclarationException: malformed custom unit definitions.
        """
        item_attributes = {"public": False}

        # Handle constants.
        if isinstance(item.annotation, ast.Call) and item.annotation.func.id == "constant":
            self.add_constant(item)
            return

        is_event = isinstance(item.annotation, ast.Call) and item.annotation.func.id == "event"

        # Events skip the keyword scan (`event` is not a global keyword); they
        # are rejected further down.
        if not is_event:
            item_name, item_attributes = self.get_item_name_and_attributes(item, item_attributes)
            if not all(attr in valid_global_keywords for attr in item_attributes):
                raise StructureException('Invalid global keyword used: %s' % item_attributes, item)

        if item.value is not None:
            raise StructureException('May not assign value whilst defining type', item)
        elif is_event:
            # TODO: event declarations are not supported yet.
            raise EventDeclarationException("Events are not supported", item)
        elif not isinstance(item.target, ast.Name):
            raise StructureException("Can only assign type to variable in top-level statement", item)

        # TODO: custom unit definition
        # Is this a custom unit definition.
        elif item.target.id == 'units':
            if not self._custom_units:
                if not isinstance(item.annotation, ast.Dict):
                    raise VariableDeclarationException("Define custom units using units: { }.", item.target)
                for key, value in zip(item.annotation.keys, item.annotation.values):
                    if not isinstance(value, ast.Str):
                        raise VariableDeclarationException("Custom unit description must be a valid string", value)
                    if not isinstance(key, ast.Name):
                        raise VariableDeclarationException("Custom unit name must be a valid string", key)
                    if key.id in self._custom_units:
                        raise VariableDeclarationException("Custom unit name may only be used once", key)
                    if not is_varname_valid(key.id, custom_units=self._custom_units):
                        raise VariableDeclarationException("Custom unit may not be a reserved keyword", key)
                    self._custom_units.append(key.id)
            else:
                raise VariableDeclarationException("Custom units can only be defined once", item.target)

        # Check if variable name is valid.
        # NOTE(review): a falsy result skips registration silently; presumably
        # is_valid_varname raises on invalid names itself -- confirm.
        elif not self.is_valid_varname(item.target.id, item):
            pass

        elif len(self._defs):
            raise StructureException("Global variables must all come before function definitions", item)

        # TODO: support premade contracts and contract-typed globals / imports.

        # If the type declaration is of the form public(<type here>), then proceed with
        # the underlying type but also add getters
        elif isinstance(item.annotation, ast.Call) and item.annotation.func.id == "public":
            if isinstance(item.annotation.args[0], ast.Name) and item_name in self._contracts:
                typ = ContractType(item_name)
            else:
                typ = parse_type(item.annotation.args[0], 'storage', custom_units=self._custom_units)
            self._globals[item.target.id] = VariableRecord(item.target.id, len(self._globals), typ, True)
            # Adding getters here
            for getter in self.mk_getter(item.target.id, typ):
                # Prefix newlines so the parsed getter reports the declaration's line number.
                self._getters.append(self.parse_line('\n' * (item.lineno - 1) + getter))
                self._getters[-1].pos = getpos(item)

        else:
            self._globals[item.target.id] = VariableRecord(
                item.target.id, len(self._globals),
                parse_type(item.annotation, 'storage', custom_units=self._custom_units),
                True
            )
Ejemplo n.º 25
0
def concat(expr, context):
    """Compile a ``concat(a, b, ...)`` builtin call into LLL.

    Every argument must be a byte array or a ``bytes32`` value. The result is
    written into a freshly allocated memory byte array whose maximum length is
    the sum of the arguments' maximum lengths (32 for each ``bytes32``).

    Args:
        expr: the AST call node; ``expr.args`` are the values to concatenate.
        context: compilation context used to parse the arguments and to
            allocate the output placeholder.

    Returns:
        LLLnode: evaluates to the memory address of the resulting byte array
        (typed ``ByteArrayType(total_maxlen)``, location ``'memory'``).

    Raises:
        StructureException: fewer than two arguments were given.
        TypeMismatchException: an argument is neither a byte array nor bytes32.
        Exception: a byte-array argument lives somewhere other than memory or
            storage (internal error).
    """
    args = [Expr(arg, context).lll_node for arg in expr.args]
    if len(args) < 2:
        raise StructureException("Concat expects at least two arguments", expr)
    for expr_arg, arg in zip(expr.args, args):
        if not isinstance(arg.typ, ByteArrayType) and not is_base_type(
                arg.typ, 'bytes32'):
            raise TypeMismatchException(
                "Concat expects byte arrays or bytes32 objects", expr_arg)
    # Maximum length of the output: each byte array contributes its maxlen,
    # each bytes32 exactly 32 bytes.
    total_maxlen = sum(
        arg.typ.maxlen if isinstance(arg.typ, ByteArrayType) else 32
        for arg in args
    )
    # Memory position of the output (length word followed by the data)
    placeholder = context.new_placeholder(ByteArrayType(total_maxlen))
    # LLL statements that fill in the output
    seq = []
    # For each argument we are concatenating...
    for arg in args:
        # ``_poz`` counts the bytes written so far: it starts at zero and is
        # incremented as each argument is appended.
        placeholder_node = LLLnode.from_list(['add', placeholder, '_poz'],
                                             typ=ByteArrayType(total_maxlen),
                                             location='memory')
        placeholder_node_plus_32 = LLLnode.from_list(
            ['add', ['add', placeholder, '_poz'], 32],
            typ=ByteArrayType(total_maxlen),
            location='memory')
        if isinstance(arg.typ, ByteArrayType):
            # Empty byte arrays contribute nothing
            if arg.typ.maxlen == 0:
                continue
            # Length of the current argument and the start of its data
            if arg.location == "memory":
                # Memory layout: length word, then data
                length = LLLnode.from_list(['mload', '_arg'],
                                           typ=BaseType('int128'))
                argstart = LLLnode.from_list(['add', '_arg', 32],
                                             typ=arg.typ,
                                             location=arg.location)
            elif arg.location == "storage":
                # Storage layout: length at sha3_32(slot), data from the next slot
                length = LLLnode.from_list(['sload', ['sha3_32', '_arg']],
                                           typ=BaseType('int128'))
                argstart = LLLnode.from_list(['add', ['sha3_32', '_arg'], 1],
                                             typ=arg.typ,
                                             location=arg.location)
            else:
                # Without this, ``length``/``argstart`` would be unbound for the
                # first argument or silently reuse the previous argument's values.
                raise Exception("Invalid location: %s" % arg.location)
            # Copy this argument's data into the output
            seq.append([
                'with',
                '_arg',
                arg,
                [
                    'seq',
                    make_byte_slice_copier(placeholder_node_plus_32,
                                           argstart,
                                           length,
                                           arg.typ.maxlen,
                                           pos=getpos(expr)),
                    # Advance the write position past the copied bytes
                    ['set', '_poz', ['add', '_poz', length]]
                ]
            ])
        else:
            # bytes32: store the whole word right after the bytes written so far
            seq.append([
                'seq',
                [
                    'mstore', ['add', placeholder_node, 32],
                    unwrap_location(arg)
                ], ['set', '_poz', ['add', '_poz', 32]]
            ])
    # After all arguments are processed, ``_poz`` equals the total length.
    # Store it in the length word to make the output a proper byte array.
    seq.append(['mstore', placeholder, '_poz'])
    # Memory location of the output
    seq.append(placeholder)
    return LLLnode.from_list(['with', '_poz', 0, ['seq'] + seq],
                             typ=ByteArrayType(total_maxlen),
                             location='memory',
                             pos=getpos(expr),
                             annotation='concat')
Ejemplo n.º 26
0
def pack_arguments(signature, args, context, pos, return_placeholder=True):
    """Lay out call data for ``signature`` in a freshly allocated memory buffer.

    The first 32-byte word of the buffer holds ``signature.method_id``. It is
    followed by one 32-byte head slot per argument:

    * base types are written directly into their head slot;
    * static lists are written inline, shifting all later head slots by
      ``32 * (count - 1)`` bytes;
    * byte arrays get an offset (``_poz``, relative to ``buffer + 32``) in their
      head slot, with the data copied at that offset. The offset is then
      advanced past the copied data, rounded up to a multiple of 32.

    Args:
        signature: the callee signature (``args``, ``method_id``, ``name``).
        args: LLL nodes for the actual argument values.
        context: compilation context used to allocate the buffer.
        pos: source position forwarded to the setters/copiers.
        return_placeholder: when true, the generated sequence ends by
            evaluating to ``buffer + 28``; private-call callers pass False.

    Returns:
        tuple: ``(lll_node, buffer_maxlen - 28, buffer + 32)``.

    Raises:
        StructureException: ``len(args)`` differs from the signature's arity.
        TypeMismatchException: an argument type cannot be packed.
    """
    declared_types = [sig_arg.typ for sig_arg in signature.args]
    buffer_typ = ByteArrayType(
        maxlen=32 * sum(get_size_of_type(t) for t in declared_types) + 32)
    buffer = context.new_placeholder(buffer_typ)

    if len(args) != len(signature.args):
        raise StructureException(
            "Wrong number of args for: %s (%s args, expected %s)" %
            (signature.name, len(args), len(signature.args)))

    code = [['mstore', buffer, signature.method_id]]
    has_dynamic = False
    static_extra = 0  # extra bytes taken by inline static lists so far

    for idx, (value, typ) in enumerate(zip(args, declared_types)):
        head_slot = buffer + 32 + static_extra + idx * 32
        if isinstance(typ, BaseType):
            dst = LLLnode.from_list(head_slot, typ=typ)
            code.append(make_setter(dst, value, 'memory', pos=pos))
        elif isinstance(typ, ByteArrayType):
            # Head slot records where the data starts.
            code.append(['mstore', head_slot, '_poz'])
            src = LLLnode.from_list('_s', typ=value.typ, location=value.location)
            dst = LLLnode.from_list(['add', buffer + 32, '_poz'],
                                    typ=typ,
                                    location='memory')
            copier = make_byte_array_copier(dst, src, pos)
            advance = ['set', '_poz',
                       ['add', 32, ['ceil32', ['add', '_poz', get_length(src)]]]]
            code.append(['with', '_s', value, ['seq', copier, advance]])
            has_dynamic = True
        elif isinstance(typ, ListType):
            dst = LLLnode.from_list([head_slot], typ=typ, location='memory')
            code.append(make_setter(dst, value, 'memory', pos=pos))
            static_extra += 32 * (typ.count - 1)
        else:
            raise TypeMismatchException("Cannot pack argument of type %r" %
                                        typ)

    # Private calls don't want the buffer address as the sequence's value.
    if return_placeholder:
        code.append([buffer + 28])

    body = ['seq'] + code
    if has_dynamic:
        # Byte-array data starts right after the (static) head section.
        body = ['with', '_poz', len(args) * 32 + static_extra, body]

    return (LLLnode.from_list(body, typ=buffer_typ, location='memory'),
            buffer_typ.maxlen - 28, buffer + 32)
Ejemplo n.º 27
0
    def parse_for_list(self):
        """Compile ``for <name> in <list>:`` into LLL.

        Three iterable forms are supported:

        * ``ast.Name`` -- a list held in a local memory variable; elements are
          read with ``mload`` from the variable's position;
        * ``ast.List`` -- a list literal, first copied into a new memory
          placeholder with ``make_setter``;
        * ``ast.Attribute`` -- a storage list; elements are read with ``sload``
          starting at ``sha3_32`` of the list's slot.

        For the variable and storage forms, the list is registered with
        ``context.set_in_for_loop`` while the body is compiled, so it cannot be
        modified during iteration. The loop variable and its hidden index
        variable (``_index_for_<name>``) are removed from the context afterwards.

        Returns:
            LLLnode: the loop code, built around ``repeat`` over the list indices.

        Raises:
            StructureException: the iterable is not a list, its element type is
                not a base type, or the iterable expression is not one of the
                supported forms.
        """
        from .parser import (
            parse_body,
            make_setter
        )

        iter_list_node = Expr(self.stmt.iter, self.context).lll_node
        # Check the type before touching ``.subtype`` so that users get a
        # compile error instead of an AttributeError.
        if not isinstance(iter_list_node.typ, ListType):
            raise StructureException('For loops allowed only on lists.', self.stmt.iter)
        if not isinstance(iter_list_node.typ.subtype, BaseType):  # Sanity check on list subtype.
            raise StructureException('For loops allowed only on basetype lists.', self.stmt.iter)
        # Any other expression form (e.g. a subscript) would fall through every
        # branch below and leave ``o`` unbound; reject it before allocating
        # loop variables.
        if not isinstance(self.stmt.iter, (ast.Name, ast.List, ast.Attribute)):
            raise StructureException(
                'For loops allowed only on list variables, list literals and storage lists.',
                self.stmt.iter
            )
        iter_var_type = self.context.vars.get(self.stmt.iter.id).typ if isinstance(self.stmt.iter, ast.Name) else None
        subtype = iter_list_node.typ.subtype.typ
        varname = self.stmt.target.id
        value_pos = self.context.new_variable(varname, BaseType(subtype))
        # NOTE(review): the index variable is typed with the element type; it is
        # only used as a raw memory counter by ``repeat`` here.
        i_pos = self.context.new_variable('_index_for_' + varname, BaseType(subtype))
        self.context.forvars[varname] = True
        if iter_var_type:  # Is a list that is already allocated to memory.
            self.context.set_in_for_loop(self.stmt.iter.id)  # make sure list cannot be altered whilst iterating.
            iter_var = self.context.vars.get(self.stmt.iter.id)
            body = [
                'seq',
                # value = list[i] (32-byte elements in memory)
                ['mstore', value_pos, ['mload', ['add', iter_var.pos, ['mul', ['mload', i_pos], 32]]]],
                parse_body(self.stmt.body, self.context)
            ]
            o = LLLnode.from_list(
                ['repeat', i_pos, 0, iter_var.size, body], typ=None, pos=getpos(self.stmt)
            )
            self.context.remove_in_for_loop(self.stmt.iter.id)
        elif isinstance(self.stmt.iter, ast.List):  # List gets defined in the for statement.
            # Allocate list to memory.
            count = iter_list_node.typ.count
            tmp_list = LLLnode.from_list(
                obj=self.context.new_placeholder(ListType(iter_list_node.typ.subtype, count)),
                typ=ListType(iter_list_node.typ.subtype, count),
                location='memory'
            )
            setter = make_setter(tmp_list, iter_list_node, 'memory', pos=getpos(self.stmt))
            body = [
                'seq',
                ['mstore', value_pos, ['mload', ['add', tmp_list, ['mul', ['mload', i_pos], 32]]]],
                parse_body(self.stmt.body, self.context)
            ]
            o = LLLnode.from_list(
                ['seq',
                    setter,
                    ['repeat', i_pos, 0, count, body]], typ=None, pos=getpos(self.stmt)
            )
        else:  # ast.Attribute: list is contained in storage.
            count = iter_list_node.typ.count
            self.context.set_in_for_loop(iter_list_node.annotation)  # make sure list cannot be altered whilst iterating.
            body = [
                'seq',
                # value = list[i]; storage elements start at sha3_32(slot), one slot each
                ['mstore', value_pos, ['sload', ['add', ['sha3_32', iter_list_node], ['mload', i_pos]]]],
                parse_body(self.stmt.body, self.context),
            ]
            o = LLLnode.from_list(
                ['seq',
                    ['repeat', i_pos, 0, count, body]], typ=None, pos=getpos(self.stmt)
            )
            self.context.remove_in_for_loop(iter_list_node.annotation)
        del self.context.vars[varname]
        del self.context.vars['_index_for_' + varname]
        del self.context.forvars[varname]
        return o
Ejemplo n.º 28
0
    def from_definition(cls,
                        code,
                        sigs=None,
                        custom_units=None,
                        contract_def=False,
                        constant=False):
        """Build a function signature object from a ``def`` AST node.

        Args:
            code: the function definition node (name, args, decorators, returns).
            sigs: forwarded to ``parse_type`` and ``get_full_sig`` when resolving
                argument and return types.
            custom_units: user-declared units, used for name validation and
                type parsing.
            contract_def: when true, a missing ``@public``/``@private`` decorator
                is tolerated (presumably for external contract declarations).
            constant: when true, the signature is marked constant regardless of
                decorators.

        Returns:
            An instance of ``cls`` built from the name, argument records, output
            type, decorator flags, canonical signature string and 4-byte method id.

        Raises:
            FunctionDeclarationException: invalid, reserved or duplicate names.
            InvalidTypeException: missing argument annotation or unsupported
                return annotation.
            StructureException: unknown decorator, conflicting visibility, or
                missing visibility outside ``contract_def``.
        """
        name = code.name
        pos = 0

        if not is_varname_valid(name, custom_units=custom_units):
            raise FunctionDeclarationException("Function name invalid: " +
                                               name)
        # Determine the arguments, expects something of the form def foo(arg1: int128, arg2: int128 ...
        args = []
        for arg in code.args.args:
            typ = arg.annotation
            if not typ:
                raise InvalidTypeException("Argument must have type", arg)
            if not is_varname_valid(arg.arg, custom_units=custom_units):
                raise FunctionDeclarationException(
                    "Argument name invalid or reserved: " + arg.arg, arg)
            if arg.arg in (x.name for x in args):
                raise FunctionDeclarationException(
                    "Duplicate function argument name: " + arg.arg, arg)
            parsed_type = parse_type(typ,
                                     None,
                                     sigs,
                                     custom_units=custom_units)
            args.append(VariableRecord(arg.arg, pos, parsed_type, False))
            # Byte arrays advance the position by a single 32-byte slot;
            # every other type advances by its full size.
            if isinstance(parsed_type, ByteArrayType):
                pos += 32
            else:
                pos += get_size_of_type(parsed_type) * 32

        # Apply decorators
        const, payable, private, public = False, False, False, False
        for dec in code.decorator_list:
            if isinstance(dec, ast.Name) and dec.id == "constant":
                const = True
            # TODO:
            # elif isinstance(dec, ast.Name) and dec.id == "payable":
            #     payable = True
            elif isinstance(dec, ast.Name) and dec.id == "private":
                private = True
            elif isinstance(dec, ast.Name) and dec.id == "public":
                public = True
            else:
                raise StructureException("Bad decorator", dec)

        if public and private:
            raise StructureException(
                "Cannot use public and private decorators on the same function: {}"
                .format(name), code)
        # if payable and const:
        #     raise StructureException("Function {} cannot be both constant and payable.".format(name))
        # if payable and private:
        #     raise StructureException("Function {} cannot be both private and payable.".format(name))
        if (not public and not private) and not contract_def:
            raise StructureException(
                "Function visibility must be declared (@public or @private)",
                code)
        if constant:
            const = True
        # Determine the return type. Expects something of the form:
        # def foo(): ...
        # def foo() -> int128: ...
        # If there is no return type, ie. it's of the form def foo(): ...
        # and NOT def foo() -> type: ..., then it's null
        if not code.returns:
            output_type = None
        elif isinstance(code.returns,
                        (ast.Name, ast.Compare, ast.Subscript, ast.Call, ast.Tuple)):
            output_type = parse_type(code.returns,
                                     None,
                                     sigs,
                                     custom_units=custom_units)
        else:
            # Describe the raw annotation; re-parsing it here (as before) would
            # itself fail or mis-parse without ``sigs``/``custom_units`` and
            # mask this message.
            raise InvalidTypeException(
                "Output type invalid or unsupported: %s" % ast.dump(code.returns),
                code.returns,
            )
        # Output type must be canonicalizable
        if output_type is not None:
            assert isinstance(output_type,
                              TupleType) or canonicalize_type(output_type)
        # Get the canonical function signature
        sig = cls.get_full_sig(name, code.args.args, sigs, custom_units)

        # Take the first 4 bytes of the hash of the sig to get the method ID
        method_id = fourbytes_to_int(sha3(bytes(sig, 'utf-8'))[:4])
        return cls(name, args, output_type, const, payable, private, sig,
                   method_id, custom_units)
Ejemplo n.º 29
0
    def parse_return(self):
        """Compile a ``return`` statement into LLL.

        The returned value (if any) is checked against
        ``self.context.return_type``, laid out in memory, and handed to
        ``self.make_return_stmt``. Base types, byte arrays, lists and tuples
        are supported.

        Returns:
            LLLnode: code that stores the return value and returns.

        Raises:
            TypeMismatchException: the value does not match the declared return
                type, or a value is present/missing contrary to the declaration.
            InvalidLiteralException: a literal number is out of range for the
                declared return type.
            StructureException: tuple length or member-kind mismatch.
            Exception: internal errors (unexpected location or tuple member type).
        """
        if self.context.return_type is None:
            if self.stmt.value:
                raise TypeMismatchException("Not expecting to return a value", self.stmt)
            return LLLnode.from_list(self.make_return_stmt(0, 0), typ=None, pos=getpos(self.stmt))
        if not self.stmt.value:
            raise TypeMismatchException("Expecting to return a value", self.stmt)

        def zero_pad(bytez_placeholder, maxlen):
            """Return LLL that writes zero bytes after a byte array's data.

            Starts at index ``mload(bytez_placeholder)`` (the stored length) in
            the data area beginning 32 bytes after ``bytez_placeholder`` and
            stops once the index exceeds ``maxlen``. A no-op when ``maxlen`` is 0.
            """
            zero_padder = LLLnode.from_list(['pass'])
            if maxlen > 0:
                zero_pad_i = self.context.new_placeholder(BaseType('uint256'))  # Iterator used to zero pad memory.
                zero_padder = LLLnode.from_list(
                    ['repeat', zero_pad_i, ['mload', bytez_placeholder], maxlen,
                        ['seq',
                            ['if', ['gt', ['mload', zero_pad_i], maxlen], 'break'],  # stay within allocated bounds
                            ['mstore8', ['add', ['add', 32, bytez_placeholder], ['mload', zero_pad_i]], 0]]],
                    annotation="Zero pad"
                )
            return zero_padder

        sub = Expr(self.stmt.value, self.context).lll_node
        self.context.increment_return_counter()
        # Returning a value (most common case)
        if isinstance(sub.typ, BaseType):
            if not isinstance(self.context.return_type, BaseType):
                raise TypeMismatchException("Trying to return base type %r, output expecting %r" % (sub.typ, self.context.return_type), self.stmt.value)
            sub = unwrap_location(sub)
            if not are_units_compatible(sub.typ, self.context.return_type):
                raise TypeMismatchException("Return type units mismatch %r %r" % (sub.typ, self.context.return_type), self.stmt.value)
            # NOTE(review): the first clause compares the return type's name
            # with ``sub.typ`` itself (a type object), which looks like it was
            # meant to be ``sub.typ.typ``. It is left as-is because enabling the
            # range check for non-integer literals may change behavior — confirm.
            elif sub.typ.is_literal and (self.context.return_type.typ == sub.typ or
                    'int' in self.context.return_type.typ and
                    'int' in sub.typ.typ):
                if not SizeLimits.in_bounds(self.context.return_type.typ, sub.value):
                    raise InvalidLiteralException("Number out of range: " + str(sub.value), self.stmt)
                else:
                    return LLLnode.from_list(['seq', ['mstore', 0, sub], self.make_return_stmt(0, 32)], typ=None, pos=getpos(self.stmt))
            # NOTE(review): no 'int256' type appears elsewhere in this code, so
            # the second clause may be dead — confirm.
            elif is_base_type(sub.typ, self.context.return_type.typ) or \
                    (is_base_type(sub.typ, 'int128') and is_base_type(self.context.return_type, 'int256')):
                return LLLnode.from_list(['seq', ['mstore', 0, sub], self.make_return_stmt(0, 32)], typ=None, pos=getpos(self.stmt))
            else:
                raise TypeMismatchException("Unsupported type conversion: %r to %r" % (sub.typ, self.context.return_type), self.stmt.value)
        # Returning a byte array
        elif isinstance(sub.typ, ByteArrayType):
            if not isinstance(self.context.return_type, ByteArrayType):
                raise TypeMismatchException("Trying to return base type %r, output expecting %r" % (sub.typ, self.context.return_type), self.stmt.value)
            if sub.typ.maxlen > self.context.return_type.maxlen:
                raise TypeMismatchException("Cannot cast from greater max-length %d to shorter max-length %d" %
                                            (sub.typ.maxlen, self.context.return_type.maxlen), self.stmt.value)

            loop_memory_position = self.context.new_placeholder(typ=BaseType('uint256'))  # loop memory has to be allocated first.
            len_placeholder = self.context.new_placeholder(typ=BaseType('uint256'))  # len & bytez placeholder have to be declared after each other at all times.
            bytez_placeholder = self.context.new_placeholder(typ=sub.typ)

            if sub.location in ('storage', 'memory'):
                return LLLnode.from_list([
                    'seq',
                    make_byte_array_copier(
                        LLLnode(bytez_placeholder, location='memory', typ=sub.typ),
                        sub,
                        pos=getpos(self.stmt)
                    ),
                    zero_pad(bytez_placeholder, sub.typ.maxlen),
                    # ABI head: offset of the byte array (32) just before it
                    ['mstore', len_placeholder, 32],
                    self.make_return_stmt(len_placeholder, ['ceil32', ['add', ['mload', bytez_placeholder], 64]], loop_memory_position=loop_memory_position)],
                    typ=None, pos=getpos(self.stmt)
                )
            else:
                raise Exception("Invalid location: %s" % sub.location)

        elif isinstance(sub.typ, ListType):
            # Compare element base types by the leading name of their repr
            sub_base_type = re.split(r'\(|\[', str(sub.typ.subtype))[0]
            ret_base_type = re.split(r'\(|\[', str(self.context.return_type.subtype))[0]
            loop_memory_position = self.context.new_placeholder(typ=BaseType('uint256'))
            if sub_base_type != ret_base_type:
                raise TypeMismatchException(
                    "List return type %r does not match specified return type, expecting %r" % (
                        sub_base_type, ret_base_type
                    ),
                    self.stmt
                )
            elif sub.location == "memory" and sub.value != "multi":
                # Already laid out in memory: return it in place.
                return LLLnode.from_list(self.make_return_stmt(sub, get_size_of_type(self.context.return_type) * 32, loop_memory_position=loop_memory_position),
                                         typ=None, pos=getpos(self.stmt))
            else:
                # Copy (e.g. a list literal or storage list) into memory first.
                new_sub = LLLnode.from_list(self.context.new_placeholder(self.context.return_type), typ=self.context.return_type, location='memory')
                setter = make_setter(new_sub, sub, 'memory', pos=getpos(self.stmt))
                return LLLnode.from_list(['seq', setter, self.make_return_stmt(new_sub, get_size_of_type(self.context.return_type) * 32, loop_memory_position=loop_memory_position)],
                                         typ=None, pos=getpos(self.stmt))

        # Returning a tuple.
        elif isinstance(sub.typ, TupleType):
            if not isinstance(self.context.return_type, TupleType):
                raise TypeMismatchException("Trying to return tuple type %r, output expecting %r" % (sub.typ, self.context.return_type), self.stmt.value)

            if len(self.context.return_type.members) != len(sub.typ.members):
                raise StructureException("Tuple lengths don't match!", self.stmt)

            # check return type matches, sub type.
            for i, ret_x in enumerate(self.context.return_type.members):
                s_member = sub.typ.members[i]
                sub_type = s_member if isinstance(s_member, NodeType) else s_member.typ
                if type(sub_type) is not type(ret_x):
                    raise StructureException(
                        "Tuple return type does not match annotated return. {} != {}".format(
                            type(sub_type), type(ret_x)
                        ),
                        self.stmt
                    )

            # Is from a call expression.
            if len(sub.args[0].args) > 0 and sub.args[0].args[0].value == 'call':  # self-call to public.
                mem_pos = sub.args[0].args[-1]
                mem_size = get_size_of_type(sub.typ) * 32
                return LLLnode.from_list(['return', mem_pos, mem_size], typ=sub.typ)

            elif (sub.annotation and 'Internal Call' in sub.annotation):
                mem_pos = sub.args[-1].value if sub.value == 'seq_unchecked' else sub.args[0].args[-1]
                mem_size = get_size_of_type(sub.typ) * 32
                # Add zero padder if bytes are present in output.
                zero_padder = ['pass']
                byte_arrays = [(i, x) for i, x in enumerate(sub.typ.members) if isinstance(x, ByteArrayType)]
                if byte_arrays:
                    i, x = byte_arrays[-1]
                    zero_padder = zero_pad(bytez_placeholder=['add', mem_pos, ['mload', mem_pos + i * 32]], maxlen=x.maxlen)
                return LLLnode.from_list(
                    ['seq', sub, zero_padder, self.make_return_stmt(mem_pos, mem_size)],
                    typ=sub.typ, pos=getpos(self.stmt))

            subs = []
            # Pre-allocate loop_memory_position if required for private function returning.
            loop_memory_position = self.context.new_placeholder(typ=BaseType('uint256')) if self.context.is_private else None
            # Allocate dynamic off set counter, to keep track of the total packed dynamic data size.
            dynamic_offset_counter_placeholder = self.context.new_placeholder(typ=BaseType('uint256'))
            dynamic_offset_counter = LLLnode(
                dynamic_offset_counter_placeholder, typ=None, annotation="dynamic_offset_counter"  # dynamic offset position counter.
            )
            new_sub = LLLnode.from_list(
                self.context.new_placeholder(typ=BaseType('uint256')), typ=self.context.return_type, location='memory', annotation='new_sub'
            )
            dynamic_offset_start = 32 * len(sub.args)  # The static list of args end.
            left_token = LLLnode.from_list('_loc', typ=new_sub.typ, location="memory")

            def get_dynamic_offset_value():
                # Get value of dynamic offset counter.
                return ['mload', dynamic_offset_counter]

            def increment_dynamic_offset(dynamic_spot):
                # Increment dynamic offset counter in memory by 32 + ceil32(len).
                return [
                    'mstore', dynamic_offset_counter,
                    ['add',
                        ['add', ['ceil32', ['mload', dynamic_spot]], 32],
                        ['mload', dynamic_offset_counter]]
                ]

            for i in range(len(sub.typ.members)):
                arg = sub.args[i]
                variable_offset = LLLnode.from_list(['add', 32 * i, left_token], typ=arg.typ, annotation='variable_offset')
                if isinstance(arg.typ, ByteArrayType):
                    # Store offset pointer value.
                    subs.append(['mstore', variable_offset, get_dynamic_offset_value()])

                    # Store dynamic data, from offset pointer onwards.
                    dynamic_spot = LLLnode.from_list(['add', left_token, get_dynamic_offset_value()], location="memory", typ=arg.typ, annotation='dynamic_spot')
                    subs.append(make_setter(dynamic_spot, arg, location="memory", pos=getpos(self.stmt)))
                    subs.append(increment_dynamic_offset(dynamic_spot))

                elif isinstance(arg.typ, BaseType):
                    subs.append(make_setter(variable_offset, arg, "memory", pos=getpos(self.stmt)))
                else:
                    # Format the message (previously the type was passed as a
                    # separate exception argument and never interpolated).
                    raise Exception("Can't return type %s as part of tuple" % type(arg.typ))

            setter = LLLnode.from_list(
                ['seq',
                    ['mstore', dynamic_offset_counter, dynamic_offset_start],
                    ['with', '_loc', new_sub, ['seq'] + subs]],
                typ=None
            )

            return LLLnode.from_list(
                ['seq',
                    setter,
                    self.make_return_stmt(new_sub, get_dynamic_offset_value(), loop_memory_position)],
                typ=None, pos=getpos(self.stmt)
            )
        else:
            raise TypeMismatchException("Can only return base type!", self.stmt)
Ejemplo n.º 30
0
    def compare(self):
        """Compile a two-operand comparison (``<``, ``<=``, ``>``, ``>=``, ``==``, ``!=``, ``in``).

        * Two byte arrays of equal ``maxlen`` (at most 32) are compared by
          loading their first data word from memory or storage.
        * ``x in <list>`` is delegated to ``self.build_in_comparator``.
        * Other operands are compared with signed LLL ops. These are switched
          to unsigned ops when both sides are ``uint256``, or when an
          in-range literal is compared across ``int128``/``uint256``.

        Returns:
            LLLnode: a ``bool``-typed comparison node.

        Raises:
            TypeMismatchException: incompatible types/units, or mismatched
                byte-array lengths.
            ParserException: byte arrays longer than 32 bytes, or a byte array
                in an unsupported location.
            StructureException: chained comparisons (``a < b < c``).
            Exception: unsupported comparison operator.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.comparators[0], self.context)

        if isinstance(left.typ, ByteArrayType) and isinstance(right.typ, ByteArrayType):
            if left.typ.maxlen != right.typ.maxlen:
                raise TypeMismatchException('Can only compare bytes of the same length', self.expr)
            if left.typ.maxlen > 32 or right.typ.maxlen > 32:
                raise ParserException('Can only compare bytes of length shorter than 32 bytes', self.expr)
        elif isinstance(self.expr.ops[0], ast.In) and \
           isinstance(right.typ, ListType):
            if not are_units_compatible(left.typ, right.typ.subtype) and not are_units_compatible(right.typ.subtype, left.typ):
                raise TypeMismatchException("Can't use IN comparison with different types!", self.expr)
            return self.build_in_comparator()
        else:
            if not are_units_compatible(left.typ, right.typ) and not are_units_compatible(right.typ, left.typ):
                raise TypeMismatchException("Can't compare values with different units!", self.expr)

        if len(self.expr.ops) != 1:
            raise StructureException("Cannot have a comparison with more than two elements", self.expr)
        if isinstance(self.expr.ops[0], ast.Gt):
            op = 'sgt'
        elif isinstance(self.expr.ops[0], ast.GtE):
            op = 'sge'
        elif isinstance(self.expr.ops[0], ast.LtE):
            op = 'sle'
        elif isinstance(self.expr.ops[0], ast.Lt):
            op = 'slt'
        elif isinstance(self.expr.ops[0], ast.Eq):
            op = 'eq'
        elif isinstance(self.expr.ops[0], ast.NotEq):
            op = 'ne'
        else:
            raise Exception("Unsupported comparison operator")

        # Compare (limited to 32) byte arrays. Both sides must be byte arrays
        # (previously ``left`` was tested twice).
        if isinstance(left.typ, ByteArrayType) and isinstance(right.typ, ByteArrayType):
            left = Expr(self.expr.left, self.context).lll_node
            right = Expr(self.expr.comparators[0], self.context).lll_node

            def load_bytearray(side):
                # Load the first data word: memory data follows the length
                # word; storage data starts one slot after sha3_32(slot).
                if side.location == 'memory':
                    return ['mload', ['add', 32, side]]
                elif side.location == 'storage':
                    return ['sload', ['add', 1, ['sha3_32', side]]]
                # Previously fell through and returned None into the LLL.
                raise ParserException(
                    'Unsupported location for byte array comparison: %s' % side.location, self.expr)

            return LLLnode.from_list(
                [op, load_bytearray(left), load_bytearray(right)], typ='bool', pos=getpos(self.expr))

        # Compare other types.
        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            if op not in ('eq', 'ne'):
                raise TypeMismatchException("Invalid type for comparison op", self.expr)
        left_type, right_type = left.typ.typ, right.typ.typ

        # Special Case: comparison of a literal integer. If in valid range allow it to be compared.
        if {left_type, right_type} == {'int128', 'uint256'} and {left.typ.is_literal, right.typ.is_literal} == {True, False}:

            comparison_allowed = False
            if left.typ.is_literal and SizeLimits.in_bounds(right_type, left.value):
                comparison_allowed = True
            elif right.typ.is_literal and SizeLimits.in_bounds(left_type, right.value):
                comparison_allowed = True
            op = self._signed_to_unsigned_comparision_op(op)

            if comparison_allowed:
                return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))

        elif left_type == right_type == 'uint256':
            op = self._signed_to_unsigned_comparision_op(op)
        elif (left_type in ('decimal', 'int128') or right_type in ('decimal', 'int128')) and left_type != right_type:
            raise TypeMismatchException(
                'Implicit conversion from {} to {} disallowed, please convert.'.format(left_type, right_type),
                self.expr
            )

        if left_type == right_type:
            return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))
        else:
            raise TypeMismatchException("Unsupported types for comparison: %r %r" % (left_type, right_type), self.expr)