Esempio n. 1
0
    def parse_Return(self):
        """Compile a ``return`` statement into an LLL node.

        The strategy depends on the type of the returned expression (base
        type, byte array, list, struct or tuple) and on the enclosing
        function's declared return type (``self.context.return_type``).
        Every combination this method cannot compile ends in a bare
        ``return``, so the method yields ``None`` instead of raising.
        NOTE(review): presumably the caller or an earlier validation pass
        reports that case -- confirm.

        Returns:
            An ``LLLnode`` (``typ=None``) that performs the return, or
            ``None`` when the returned value does not fit the declared
            return type.
        """
        if self.context.return_type is None:
            # No declared return type: only a value-less ``return`` compiles,
            # to a zero-length return from memory offset 0.
            if self.stmt.value:
                return
            return LLLnode.from_list(
                make_return_stmt(self.stmt, self.context, 0, 0),
                typ=None,
                pos=getpos(self.stmt),
                valency=0,
            )

        sub = Expr(self.stmt.value, self.context).lll_node

        # Returning a value (most common case)
        if isinstance(sub.typ, BaseType):
            sub = unwrap_location(sub)

            # A non-literal value must have exactly the declared return type.
            if self.context.return_type != sub.typ and not sub.typ.is_literal:
                return
            # A literal is range-checked against the declared type when the
            # types match or both type names contain "int" (`and` binds
            # tighter than `or`). An out-of-range literal falls through to
            # the final bare ``return``.
            # NOTE(review): the first comparison below compares a type-name
            # string (``return_type.typ``) with a type object (``sub.typ``),
            # so it is likely never true. Same-type non-int literals are
            # instead handled by the ``is_base_type`` branch further down.
            elif sub.typ.is_literal and (
                    self.context.return_type.typ == sub.typ
                    or "int" in self.context.return_type.typ
                    and "int" in sub.typ.typ):  # noqa: E501
                if SizeLimits.in_bounds(self.context.return_type.typ,
                                        sub.value):
                    # Store the single word at memory 0 and return 32 bytes.
                    return LLLnode.from_list(
                        [
                            "seq",
                            ["mstore", 0, sub],
                            make_return_stmt(self.stmt, self.context, 0, 32),
                        ],
                        typ=None,
                        pos=getpos(self.stmt),
                        valency=0,
                    )
            # Same base type, or an int128 value returned from a function
            # declared as int256.
            elif is_base_type(sub.typ, self.context.return_type.typ) or (
                    is_base_type(sub.typ, "int128") and is_base_type(
                        self.context.return_type, "int256")):  # noqa: E501
                return LLLnode.from_list(
                    [
                        "seq", ["mstore", 0, sub],
                        make_return_stmt(self.stmt, self.context, 0, 32)
                    ],
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
            return
        # Returning a byte array
        elif isinstance(sub.typ, ByteArrayLike):
            # Must be the same byte-array kind and no longer than declared.
            if not sub.typ.eq_base(self.context.return_type):
                return
            if sub.typ.maxlen > self.context.return_type.maxlen:
                return

            # loop memory has to be allocated first.
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType("uint256"))
            # len & bytez placeholder have to be declared after each other at all times.
            len_placeholder = self.context.new_placeholder(
                typ=BaseType("uint256"))
            bytez_placeholder = self.context.new_placeholder(typ=sub.typ)

            # Only storage and memory sources are copied. Any other location
            # falls through to the bare ``return`` below.
            if sub.location in ("storage", "memory"):
                # Copy the bytes into the placeholder and zero-pad them. Then
                # write 32 into the word just before them and return the
                # region [len_placeholder, ceil32(length + 64)). That region
                # is the 32-word, then the length word, then the data; it is
                # presumably the ABI head for a dynamic type -- confirm.
                return LLLnode.from_list(
                    [
                        "seq",
                        make_byte_array_copier(
                            LLLnode(bytez_placeholder,
                                    location="memory",
                                    typ=sub.typ),
                            sub,
                            pos=getpos(self.stmt),
                        ),
                        zero_pad(bytez_placeholder),
                        ["mstore", len_placeholder, 32],
                        make_return_stmt(
                            self.stmt,
                            self.context,
                            len_placeholder,
                            [
                                "ceil32",
                                ["add", ["mload", bytez_placeholder], 64]
                            ],
                            loop_memory_position=loop_memory_position,
                        ),
                    ],
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
            return

        elif isinstance(sub.typ, ListType):
            # Allocated before the type check, so it is also allocated on the
            # early-return path.
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType("uint256"))
            if sub.typ != self.context.return_type:
                return
            # A list already laid out in memory is returned in place.
            elif sub.location == "memory" and sub.value != "multi":
                return LLLnode.from_list(
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        sub,
                        get_size_of_type(self.context.return_type) * 32,
                        loop_memory_position=loop_memory_position,
                    ),
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
            else:
                # Storage lists or list literals ("multi") are first copied
                # into a fresh memory placeholder, which is then returned.
                new_sub = LLLnode.from_list(
                    self.context.new_placeholder(self.context.return_type),
                    typ=self.context.return_type,
                    location="memory",
                )
                setter = make_setter(new_sub,
                                     sub,
                                     "memory",
                                     pos=getpos(self.stmt))
                # NOTE(review): unlike every other branch, no ``valency=0``
                # is passed here -- confirm whether that is intentional.
                return LLLnode.from_list(
                    [
                        "seq",
                        setter,
                        make_return_stmt(
                            self.stmt,
                            self.context,
                            new_sub,
                            get_size_of_type(self.context.return_type) * 32,
                            loop_memory_position=loop_memory_position,
                        ),
                    ],
                    typ=None,
                    pos=getpos(self.stmt),
                )

        # Returning a struct
        elif isinstance(sub.typ, StructType):
            retty = self.context.return_type
            # Structs are matched by name. On mismatch, execution falls off
            # the end of the method and returns None implicitly.
            if isinstance(retty, StructType) and retty.name == sub.typ.name:
                return gen_tuple_return(self.stmt, self.context, sub)

        # Returning a tuple.
        elif isinstance(sub.typ, TupleType):
            if not isinstance(self.context.return_type, TupleType):
                return

            if len(self.context.return_type.members) != len(sub.typ.members):
                return

            # check return type matches, sub type.
            # Members may be type objects or nodes carrying ``.typ``. Only
            # the Python class of each member type is compared, not full
            # type equality.
            for i, ret_x in enumerate(self.context.return_type.members):
                s_member = sub.typ.members[i]
                sub_type = s_member if isinstance(s_member,
                                                  NodeType) else s_member.typ
                if type(sub_type) is not type(ret_x):
                    return
            return gen_tuple_return(self.stmt, self.context, sub)
Esempio n. 2
0
def process_arg(index, arg, expected_arg_typelist, function_name, context):
    """Validate and convert one argument of a built-in function call.

    ``expected_arg_typelist`` is a single expected-type spec, a tuple of
    alternatives, or an ``Optional`` wrapper around either. Alternatives are
    tried in order, and the first one ``arg`` satisfies decides the result:

    * ``'num_literal'``   -> the number of an int/decimal literal, or the
      value of a ``uint256``/``int128`` constant.
    * ``'str_literal'``   -> the ``bytes`` of a string literal; every
      character must have code point < 256.
    * ``'bytes_literal'`` -> the ``bytes`` of a bytes literal.
    * ``'name_literal'``  -> the identifier of a name, or ``'bytes[N]'`` for
      a ``bytes[N]`` subscript.
    * ``'*'``             -> ``arg`` unchanged.
    * ``'bytes'`` / ``'string'`` -> the compiled node of a byte-array /
      string expression.
    * any other type string -> the compiled node, if its type matches.

    Args:
        index: argument position, used only in error messages.
        arg: the argument's AST node.
        expected_arg_typelist: expected-type spec(s), as described above.
        function_name: name of the built-in, used only in error messages.
        context: compilation context.

    Raises:
        InvalidLiteral: a string literal contains a character >= 256.
        TypeMismatch: no alternative matched.
    """
    if isinstance(expected_arg_typelist, Optional):
        expected_arg_typelist = expected_arg_typelist.typ
    if not isinstance(expected_arg_typelist, tuple):
        expected_arg_typelist = (expected_arg_typelist, )

    # ``arg`` is compiled lazily, at most once per form:
    # - ``value_sub`` is the value expression, used for base types;
    # - ``node_sub`` is the plain compiled node, used for byte arrays,
    #   strings and compound types.
    # The node that passed the type check is the one returned, rather than
    # a second compilation of ``arg``. Explicit ``is None`` checks avoid
    # depending on the truthiness of a compiled node.
    value_sub = None
    node_sub = None
    for expected_arg in expected_arg_typelist:
        if expected_arg == 'num_literal':
            if context.constants.is_constant_of_base_type(arg, ('uint256', 'int128')):
                return context.constants.get_constant(arg.id, None).value
            if isinstance(arg, (vy_ast.Int, vy_ast.Decimal)):
                return arg.n
        elif expected_arg == 'str_literal':
            if isinstance(arg, vy_ast.Str):
                bytez = b''
                for c in arg.s:
                    if ord(c) >= 256:
                        raise InvalidLiteral(
                            f"Cannot insert special character {c} into byte array",
                            arg,
                        )
                    bytez += bytes([ord(c)])
                return bytez
        elif expected_arg == 'bytes_literal':
            if isinstance(arg, vy_ast.Bytes):
                return arg.s
        elif expected_arg == 'name_literal':
            if isinstance(arg, vy_ast.Name):
                return arg.id
            # Guard the ``.id`` access. A subscript on a non-name
            # (e.g. ``a.b[1]``) then falls through to TypeMismatch instead of
            # raising AttributeError.
            elif (isinstance(arg, vy_ast.Subscript)
                    and isinstance(arg.value, vy_ast.Name)
                    and arg.value.id == 'bytes'):
                return f'bytes[{arg.slice.value.n}]'
        elif expected_arg == '*':
            return arg
        elif expected_arg == 'bytes':
            if node_sub is None:
                node_sub = Expr(arg, context).lll_node
            if isinstance(node_sub.typ, ByteArrayType):
                return node_sub
        elif expected_arg == 'string':
            if node_sub is None:
                node_sub = Expr(arg, context).lll_node
            if isinstance(node_sub.typ, StringType):
                return node_sub
        else:
            # Does not work for unit-endowed types inside compound types, e.g. timestamp[2]
            parsed_expected_type = context.parse_type(
                vy_ast.parse_to_ast(expected_arg)[0].value,
                'memory',
            )
            if isinstance(parsed_expected_type, BaseType):
                if value_sub is None:
                    value_sub = Expr.parse_value_expr(arg, context)

                # An int128/uint256 literal is also accepted for the other
                # integer type when its value fits that type.
                is_valid_integer = (
                    (
                        expected_arg in ('int128', 'uint256')
                        and isinstance(value_sub.typ, BaseType)
                    ) and (
                        value_sub.typ.typ in ('int128', 'uint256')
                        and value_sub.typ.is_literal
                    ) and (
                        SizeLimits.in_bounds(expected_arg, value_sub.value)
                    )
                )

                if is_base_type(value_sub.typ, expected_arg) or is_valid_integer:
                    return value_sub
            else:
                if node_sub is None:
                    node_sub = Expr(arg, context).lll_node
                if node_sub.typ == parsed_expected_type:
                    return node_sub
    if len(expected_arg_typelist) == 1:
        raise TypeMismatch(
            f"Expecting {expected_arg} for argument {index} of {function_name}",
            arg
        )
    else:
        raise TypeMismatch(
            f"Expecting one of {expected_arg_typelist} for argument {index} of {function_name}",
            arg
        )
Esempio n. 3
0
def base_type_conversion(orig, frm, to, pos):
    """Convert an LLL value between base types, if an implicit conversion exists.

    Args:
        orig: LLL node holding the value; its location is unwrapped first.
        frm: type of ``orig`` (must be a ``BaseType`` or ``NullType``).
        to: target type (must be a ``BaseType``).
        pos: source position attached to raised exceptions.

    Returns:
        A new ``LLLnode`` carrying the converted value with the target type.

    Raises:
        InvalidLiteralException: an ``int128``/``uint256`` literal is out of
            range for its own type.
        TypeMismatchException: either side is not a base type, or no
            conversion from ``frm`` to ``to`` is available.
    """
    orig = unwrap_location(orig)
    if (getattr(frm, 'is_literal', False)
            and frm.typ in ('int128', 'uint256')
            and not SizeLimits.in_bounds(frm.typ, orig.value)):
        raise InvalidLiteralException("Number out of range: " + str(orig.value), pos)
    if not isinstance(frm, (BaseType, NullType)) or not isinstance(to, BaseType):
        raise TypeMismatchException("Base type conversion from or to non-base type: %r %r" % (frm, to), pos)
    elif is_base_type(frm, to.typ) and are_units_compatible(frm, to):
        # Same base type: re-type only.
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    elif is_base_type(frm, 'int128') and is_base_type(to, 'decimal') and are_units_compatible(frm, to):
        # Scale the integer into fixed-point decimal representation.
        return LLLnode.from_list(['mul', orig, DECIMAL_DIVISOR], typ=BaseType('decimal', to.unit, to.positional))
    elif is_base_type(frm, 'uint256') and is_base_type(to, 'int128') and are_units_compatible(frm, to):
        # Unsigned clamp orig <= MAXNUM. Keep the target's unit and
        # positional flag, as the int128 -> decimal branch does; they
        # were previously dropped here.
        return LLLnode.from_list(
            ['uclample', orig, ['mload', MemoryPositions.MAXNUM]],
            typ=BaseType('int128', to.unit, to.positional),
        )
    elif isinstance(frm, NullType):
        if to.typ not in ('int128', 'bool', 'uint256', 'address', 'bytes32', 'decimal'):
            # This is only to future proof the use of  base_type_conversion.
            raise TypeMismatchException("Cannot convert null-type object to type %r" % to, pos)  # pragma: no cover
        return LLLnode.from_list(0, typ=to)
    elif isinstance(to, ContractType) and frm.typ == 'address':
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    # Integer literal conversion.
    elif (frm.typ, to.typ, frm.is_literal) == ('int128', 'uint256', True):
        # Range was already validated by the literal check at the top.
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    else:
        raise TypeMismatchException("Typecasting from base type %r to %r unavailable" % (frm, to), pos)
Esempio n. 4
0
def to_uint256(expr, args, kwargs, context):
    """Compile a conversion of ``args[0]`` to ``uint256``.

    Args:
        expr: the call's AST node, used for source positions and errors.
        args: converted call arguments. ``args[0]`` is a Python ``int`` or
            ``float`` for numeric literals, otherwise an ``LLLnode``.
        kwargs: unused.
        context: unused.

    Returns:
        An ``LLLnode`` typed ``uint256``. Units carry over from ``int128``
        and ``decimal`` inputs.

    Raises:
        InvalidLiteralException: a literal is out of range or of unknown
            numeric type, a bytes input is longer than 32 bytes, or the
            input type is unsupported.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)
    _unit = in_arg.typ.unit if input_type in ('int128', 'decimal') else None

    if input_type == 'num_literal':
        # Floats are truncated toward zero. Every error now carries ``expr``
        # so it can be located, like the other raises in this function.
        if isinstance(in_arg, int):
            value = in_arg
        elif isinstance(in_arg, float):
            value = math.trunc(in_arg)
        else:
            raise InvalidLiteralException(f"Unknown numeric literal type: {in_arg}", expr)
        if not SizeLimits.in_bounds('uint256', value):
            raise InvalidLiteralException(f"Number out of range: {value}", expr)
        return LLLnode.from_list(
            value,
            typ=BaseType('uint256', _unit),
            pos=getpos(expr)
        )

    elif isinstance(in_arg, LLLnode) and input_type == 'int128':
        # Reject negative values at runtime.
        return LLLnode.from_list(
            ['clampge', in_arg, 0],
            typ=BaseType('uint256', _unit),
            pos=getpos(expr)
        )

    elif isinstance(in_arg, LLLnode) and input_type == 'decimal':
        # Reject negatives, then drop the fractional part.
        return LLLnode.from_list(
            ['div', ['clampge', in_arg, 0], DECIMAL_DIVISOR],
            typ=BaseType('uint256', _unit),
            pos=getpos(expr)
        )

    elif isinstance(in_arg, LLLnode) and input_type == 'bool':
        return LLLnode.from_list(
            in_arg,
            typ=BaseType('uint256'),
            pos=getpos(expr)
        )

    elif isinstance(in_arg, LLLnode) and input_type in ('bytes32', 'address'):
        # Reinterpret the 32-byte word.
        return LLLnode(
            value=in_arg.value,
            args=in_arg.args,
            typ=BaseType('uint256'),
            pos=getpos(expr)
        )

    elif isinstance(in_arg, LLLnode) and input_type == 'bytes':
        if in_arg.typ.maxlen > 32:
            raise InvalidLiteralException(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to uint256",
                expr,
            )
        return byte_array_to_num(in_arg, expr, 'uint256')

    else:
        raise InvalidLiteralException(f"Invalid input for uint256: {in_arg}", expr)
Esempio n. 5
0
def to_int128(expr, args, kwargs, context):
    """Compile a conversion of ``args[0]`` to ``int128``.

    Args:
        expr: the call's AST node, used for source positions and errors.
        args: converted call arguments. ``args[0]`` is a Python ``int`` or
            ``float`` for numeric literals, otherwise an ``LLLnode``.
        kwargs: unused.
        context: unused.

    Returns:
        An ``LLLnode`` typed ``int128``. Units carry over from ``uint256``
        and ``decimal`` inputs.

    Raises:
        InvalidLiteralException: a literal is out of range or of unknown
            numeric type, or the input type is unsupported.
        TypeMismatchException: a bytes/string input is longer than 32 bytes.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)
    _unit = in_arg.typ.unit if input_type in ('uint256', 'decimal') else None

    if input_type == 'num_literal':
        # Floats are truncated toward zero. Every error now carries ``expr``
        # so it can be located, like the other raises in this function.
        if isinstance(in_arg, int):
            value = in_arg
        elif isinstance(in_arg, float):
            value = math.trunc(in_arg)
        else:
            raise InvalidLiteralException(f"Unknown numeric literal type: {in_arg}", expr)
        if not SizeLimits.in_bounds('int128', value):
            raise InvalidLiteralException(f"Number out of range: {value}", expr)
        return LLLnode.from_list(
            value,
            typ=BaseType('int128', _unit),
            pos=getpos(expr)
        )

    elif input_type == 'bytes32':
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds('int128', in_arg.value):
                raise InvalidLiteralException(f"Number out of range: {in_arg.value}", expr)
            else:
                return LLLnode.from_list(
                    in_arg,
                    typ=BaseType('int128', _unit),
                    pos=getpos(expr)
                )
        else:
            # Runtime signed clamp into [MINNUM, MAXNUM].
            return LLLnode.from_list(
                [
                    'clamp',
                    ['mload', MemoryPositions.MINNUM],
                    in_arg,
                    ['mload', MemoryPositions.MAXNUM],
                ],
                typ=BaseType('int128', _unit),
                pos=getpos(expr)
            )

    elif input_type == 'address':
        # Mask to the address width, then sign-extend from byte index 15.
        return LLLnode.from_list(
            [
                'signextend',
                15,
                [
                    'and',
                    in_arg,
                    (SizeLimits.ADDRSIZE - 1)
                ],
            ],
            typ=BaseType('int128', _unit),
            pos=getpos(expr)
        )

    elif input_type in ('string', 'bytes'):
        if in_arg.typ.maxlen > 32:
            raise TypeMismatchException(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to int128",
                expr,
            )
        return byte_array_to_num(in_arg, expr, 'int128')

    elif input_type == 'uint256':
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds('int128', in_arg.value):
                raise InvalidLiteralException(f"Number out of range: {in_arg.value}", expr)
            else:
                return LLLnode.from_list(
                    in_arg,
                    typ=BaseType('int128', _unit),
                    pos=getpos(expr)
                )

        else:
            # Runtime unsigned clamp: in_arg <= MAXNUM.
            return LLLnode.from_list(
                ['uclample', in_arg, ['mload', MemoryPositions.MAXNUM]],
                typ=BaseType('int128', _unit),
                pos=getpos(expr)
            )

    elif input_type == 'decimal':
        # Drop the fractional part (signed division), then clamp.
        return LLLnode.from_list(
            [
                'clamp',
                ['mload', MemoryPositions.MINNUM],
                ['sdiv', in_arg, DECIMAL_DIVISOR],
                ['mload', MemoryPositions.MAXNUM],
            ],
            typ=BaseType('int128', _unit),
            pos=getpos(expr)
        )

    elif input_type == 'bool':
        return LLLnode.from_list(
            in_arg,
            typ=BaseType('int128', _unit),
            pos=getpos(expr)
        )

    else:
        raise InvalidLiteralException(f"Invalid input for int128: {in_arg}", expr)
Esempio n. 6
0
File: expr.py Progetto: zutobg/vyper
    def arithmetic(self):
        """Compile a binary arithmetic expression (``+ - * / % **``).

        Both operands are compiled as value expressions and must be numeric.
        When both are integer literals, the result is folded at compile time
        into a new literal. For a ``uint256``/``int128`` pair, a literal
        operand that fits in ``uint256`` is re-typed as ``uint256``. After
        that, both operands must share one base type. Units and positional
        flags are combined per operator. The result is wrapped in runtime
        overflow/bounds checks for its type.

        Returns:
            An ``LLLnode`` of type ``int128``, ``decimal`` or ``uint256``.

        Raises:
            TypeMismatchException: non-numeric operands, implicit type
                conversion, unit/positional misuse, or unsupported exponents.
            ParserException: an unsupported literal operator, or division or
                modulus by a literal zero during constant folding.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.right, self.context)
        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            raise TypeMismatchException("Unsupported types for arithmetic op: %r %r" % (left.typ, right.typ), self.expr)

        arithmetic_pair = {left.typ.typ, right.typ.typ}

        # Special Case: Simplify any literal to literal arithmetic at compile time.
        if left.typ.is_literal and right.typ.is_literal and \
           isinstance(right.value, int) and isinstance(left.value, int):

            if isinstance(self.expr.op, ast.Add):
                val = left.value + right.value
            elif isinstance(self.expr.op, ast.Sub):
                val = left.value - right.value
            elif isinstance(self.expr.op, ast.Mult):
                val = left.value * right.value
            elif isinstance(self.expr.op, (ast.Div, ast.Mod)):
                # Report a literal zero divisor as a compile error instead of
                # letting Python's ZeroDivisionError escape.
                if right.value == 0:
                    raise ParserException('Cannot divide by zero in literal arithmetic', self.expr)
                # Round toward zero, matching the EVM's sdiv/smod used for
                # int128 at runtime below. Python's // and % round toward
                # negative infinity, so e.g. -7 / 2 would fold to -4 instead
                # of -3. For non-negative operands both agree.
                quotient = abs(left.value) // abs(right.value)
                if (left.value < 0) != (right.value < 0):
                    quotient = -quotient
                if isinstance(self.expr.op, ast.Div):
                    val = quotient
                else:
                    # Remainder takes the sign of the dividend, as with smod.
                    val = left.value - right.value * quotient
            elif isinstance(self.expr.op, ast.Pow):
                val = left.value ** right.value
            else:
                raise ParserException('Unsupported literal operator: %s' % str(type(self.expr.op)), self.expr)

            num = ast.Num(val)
            num.source_code = self.expr.source_code
            num.lineno = self.expr.lineno
            num.col_offset = self.expr.col_offset

            return Expr.parse_value_expr(num, self.context)

        # Special case with uint256 were int literal may be casted.
        if arithmetic_pair == {'uint256', 'int128'}:
            # Check right side literal.
            if right.typ.is_literal and SizeLimits.in_bounds('uint256', right.value):
                right = LLLnode.from_list(right.value, typ=BaseType('uint256', None, is_literal=True), pos=getpos(self.expr))
            # Check left side literal.
            elif left.typ.is_literal and SizeLimits.in_bounds('uint256', left.value):
                left = LLLnode.from_list(left.value, typ=BaseType('uint256', None, is_literal=True), pos=getpos(self.expr))

        # Only allow explicit conversions to occur.
        if left.typ.typ != right.typ.typ:
            raise TypeMismatchException("Cannot implicitly convert {} to {}.".format(left.typ.typ, right.typ.typ), self.expr)

        ltyp, rtyp = left.typ.typ, right.typ.typ
        if isinstance(self.expr.op, (ast.Add, ast.Sub)):
            if left.typ.unit != right.typ.unit and left.typ.unit is not None and right.typ.unit is not None:
                raise TypeMismatchException("Unit mismatch: %r %r" % (left.typ.unit, right.typ.unit), self.expr)
            if left.typ.positional and right.typ.positional and isinstance(self.expr.op, ast.Add):
                raise TypeMismatchException("Cannot add two positional units!", self.expr)
            new_unit = left.typ.unit or right.typ.unit
            new_positional = left.typ.positional ^ right.typ.positional  # xor, as subtracting two positionals gives a delta
            op = 'add' if isinstance(self.expr.op, ast.Add) else 'sub'
            if ltyp == 'uint256' and isinstance(self.expr.op, ast.Add):
                o = LLLnode.from_list(['seq',
                                # Checks that: a + b >= a
                                ['assert', ['ge', ['add', left, right], left]],
                                ['add', left, right]], typ=BaseType('uint256', new_unit, new_positional), pos=getpos(self.expr))
            elif ltyp == 'uint256' and isinstance(self.expr.op, ast.Sub):
                o = LLLnode.from_list(['seq',
                                # Checks that: a >= b
                                ['assert', ['ge', left, right]],
                                ['sub', left, right]], typ=BaseType('uint256', new_unit, new_positional), pos=getpos(self.expr))
            elif ltyp == rtyp:
                o = LLLnode.from_list([op, left, right], typ=BaseType(ltyp, new_unit, new_positional), pos=getpos(self.expr))
            else:
                raise Exception("Unsupported Operation '%r(%r, %r)'" % (op, ltyp, rtyp))
        elif isinstance(self.expr.op, ast.Mult):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot multiply positional values!", self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit)
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list(['if', ['eq', left, 0], [0],
                                      ['seq', ['assert', ['eq', ['div', ['mul', left, right], left], right]],
                                      ['mul', left, right]]], typ=BaseType('uint256', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                o = LLLnode.from_list(['mul', left, right], typ=BaseType('int128', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp == 'decimal':
                o = LLLnode.from_list(['with', 'r', right, ['with', 'l', left,
                                        ['with', 'ans', ['mul', 'l', 'r'],
                                            ['seq',
                                                ['assert', ['or', ['eq', ['sdiv', 'ans', 'l'], 'r'], ['iszero', 'l']]],
                                                ['sdiv', 'ans', DECIMAL_DIVISOR]]]]], typ=BaseType('decimal', new_unit), pos=getpos(self.expr))
            else:
                raise Exception("Unsupported Operation 'mul(%r, %r)'" % (ltyp, rtyp))
        elif isinstance(self.expr.op, ast.Div):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot divide positional values!", self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit, div=True)
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list(['seq',
                                # Checks that:  b != 0
                                ['assert', right],
                                ['div', left, right]], typ=BaseType('uint256', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                o = LLLnode.from_list(['sdiv', left, ['clamp_nonzero', right]], typ=BaseType('int128', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp == 'decimal':
                o = LLLnode.from_list(['with', 'l', left, ['with', 'r', ['clamp_nonzero', right],
                                            ['sdiv', ['mul', 'l', DECIMAL_DIVISOR], 'r']]],
                                      typ=BaseType('decimal', new_unit), pos=getpos(self.expr))
            else:
                raise Exception("Unsupported Operation 'div(%r, %r)'" % (ltyp, rtyp))
        elif isinstance(self.expr.op, ast.Mod):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot use positional values as modulus arguments!", self.expr)
            if left.typ.unit != right.typ.unit and left.typ.unit is not None and right.typ.unit is not None:
                raise TypeMismatchException("Modulus arguments must have same unit", self.expr)
            new_unit = left.typ.unit or right.typ.unit
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list(['seq',
                                ['assert', right],
                                ['mod', left, right]], typ=BaseType('uint256', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp:
                o = LLLnode.from_list(['smod', left, ['clamp_nonzero', right]], typ=BaseType(ltyp, new_unit), pos=getpos(self.expr))
            else:
                raise Exception("Unsupported Operation 'mod(%r, %r)'" % (ltyp, rtyp))
        elif isinstance(self.expr.op, ast.Pow):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot use positional values as exponential arguments!", self.expr)
            if right.typ.unit:
                raise TypeMismatchException("Cannot use unit values as exponents", self.expr)
            if ltyp != 'int128' and ltyp != 'uint256' and isinstance(self.expr.right, ast.Name):
                raise TypeMismatchException("Cannot use dynamic values as exponents, for unit base types", self.expr)
            if ltyp == rtyp == 'uint256':
                # Overflow check. It passes trivially for exponent 0 or 1,
                # and for base 0 or 1. 0**n and 1**n never grow, so the
                # `left < left**right` test wrongly rejected them (e.g. 0**2
                # and 1**5 reverted).
                # NOTE(review): `left < left**right` is only a heuristic; a
                # wrapped result can still exceed `left`.
                o = LLLnode.from_list(['seq',
                                        ['assert', ['or', ['or', ['eq', right, 1], ['iszero', right]],
                                                    ['or', ['lt', left, 2], ['lt', left, ['exp', left, right]]]]],
                                        ['exp', left, right]], typ=BaseType('uint256'), pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                new_unit = left.typ.unit
                if left.typ.unit and not isinstance(self.expr.right, ast.Name):
                    new_unit = {left.typ.unit.copy().popitem()[0]: self.expr.right.n}
                o = LLLnode.from_list(['exp', left, right], typ=BaseType('int128', new_unit), pos=getpos(self.expr))
            else:
                raise TypeMismatchException('Only whole number exponents are supported', self.expr)
        else:
            raise Exception("Unsupported binop: %r" % self.expr.op)
        # Signed results are range-clamped at runtime; uint256 paths carry
        # their own asserts above.
        if o.typ.typ == 'int128':
            return LLLnode.from_list(['clamp', ['mload', MemoryPositions.MINNUM], o, ['mload', MemoryPositions.MAXNUM]], typ=o.typ, pos=getpos(self.expr))
        elif o.typ.typ == 'decimal':
            return LLLnode.from_list(['clamp', ['mload', MemoryPositions.MINDECIMAL], o, ['mload', MemoryPositions.MAXDECIMAL]], typ=o.typ, pos=getpos(self.expr))
        if o.typ.typ == 'uint256':
            return o
        else:
            raise Exception("%r %r" % (o, o.typ))
Esempio n. 7
0
def make_setter(left, right, location, pos, in_function_call=False):
    """Build LLL that assigns the value ``right`` to the destination ``left``.

    Dispatches on ``left.typ``:

    * base types: ``right`` is passed through ``base_type_conversion`` and an
      ``sstore`` (``location == 'storage'``) or ``mstore``
      (``location == 'memory'``) is emitted;
    * byte arrays: delegated to ``make_byte_array_copier``;
    * mappings: rejected;
    * lists, structs and tuples: one recursive ``make_setter`` call per
      element, wrapped in ``with`` bindings named ``_L`` (and ``_R``).

    :param left: LLL node for the destination.
    :param right: LLL node for the value to store. A ``right.value`` of
        ``"multi"`` is treated as a literal whose elements are
        ``right.args``; for a ``NullType`` right each element is set from a
        fresh ``NullType`` node.
    :param location: passed down to the base-type case to pick the store
        opcode ('storage' or 'memory').
    :param pos: source position attached to raised exceptions.
    :param in_function_call: forwarded to ``base_type_conversion``.
    :returns: the setter LLL node. NOTE(review): for a base-type ``left``
        with a ``location`` other than 'storage'/'memory' no branch matches
        and the function implicitly returns ``None``.
    :raises InvalidLiteralException: an integer literal out of range for the
        destination, or a struct right-hand side with a ``None`` element.
    :raises TypeMismatchException: a mapping destination, or mismatched
        types / element counts / struct keys / tuple lengths.
    :raises ConstancyViolationException: tuple assignment whose targets
        include a calldata variable.
    :raises Exception: a multi-valued (literal) list or struct target, or an
        unsupported ``left.typ``.
    """
    # Basic types
    if isinstance(left.typ, BaseType):
        right = base_type_conversion(
            right,
            right.typ,
            left.typ,
            pos,
            in_function_call=in_function_call,
        )
        # TODO this overlaps a type check in parser.stmt.Stmt._check_valid_assign
        # and should be examined during a refactor (@iamdefinitelyahuman)
        if 'int' in left.typ.typ and isinstance(right.value, int):
            if not SizeLimits.in_bounds(left.typ.typ, right.value):
                raise InvalidLiteralException(
                    f"Number out of range for {left.typ}: {right.value}", pos)
        # NOTE(review): there is no else branch -- any other location falls
        # through and the function returns None.
        if location == 'storage':
            return LLLnode.from_list(['sstore', left, right], typ=None)
        elif location == 'memory':
            return LLLnode.from_list(['mstore', left, right], typ=None)
    # Byte arrays
    elif isinstance(left.typ, ByteArrayLike):
        return make_byte_array_copier(left, right, pos)
    # Can't copy mappings
    elif isinstance(left.typ, MappingType):
        raise TypeMismatchException(
            "Cannot copy mappings; can only copy individual elements", pos)
    # Arrays
    elif isinstance(left.typ, ListType):
        # Cannot do something like [a, b, c] = [1, 2, 3]
        if left.value == "multi":
            raise Exception("Target of set statement must be a single item")
        if not isinstance(right.typ, (ListType, NullType)):
            raise TypeMismatchException(
                f"Setter type mismatch: left side is array, right side is {right.typ}",
                pos)
        # '_L' is the name the destination is bound to by the 'with' nodes
        # emitted below; element setters address it through this token.
        left_token = LLLnode.from_list('_L',
                                       typ=left.typ,
                                       location=left.location)
        # For storage destinations, wrap the base in sha3_32 once and mark
        # the location 'storage_prehashed' (presumably so that
        # add_variable_offset does not hash again per element -- confirm).
        if left.location == "storage":
            left = LLLnode.from_list(['sha3_32', left],
                                     typ=left.typ,
                                     location="storage_prehashed")
            left_token.location = "storage_prehashed"
        # Type checks
        # NOTE(review): the inner non-ListType raise looks unreachable -- the
        # isinstance check above already rejects anything that is not a
        # ListType or NullType.
        if not isinstance(right.typ, NullType):
            if not isinstance(right.typ, ListType):
                raise TypeMismatchException(
                    "Left side is array, right side is not", pos)
            if left.typ.count != right.typ.count:
                raise TypeMismatchException("Mismatched number of elements",
                                            pos)
        # If the right side is a literal
        # (indices come from range(count), so bounds checks are disabled)
        if right.value == "multi":
            if len(right.args) != left.typ.count:
                raise TypeMismatchException("Mismatched number of elements",
                                            pos)
            subs = []
            for i in range(left.typ.count):
                subs.append(
                    make_setter(add_variable_offset(
                        left_token,
                        LLLnode.from_list(i, typ='int128'),
                        pos=pos,
                        array_bounds_check=False,
                    ),
                                right.args[i],
                                location,
                                pos=pos))
            return LLLnode.from_list(['with', '_L', left, ['seq'] + subs],
                                     typ=None)
        # If the right side is a null
        # CC 20190619 probably not needed as of #1106
        elif isinstance(right.typ, NullType):
            subs = []
            for i in range(left.typ.count):
                subs.append(
                    make_setter(add_variable_offset(
                        left_token,
                        LLLnode.from_list(i, typ='int128'),
                        pos=pos,
                        array_bounds_check=False,
                    ),
                                LLLnode.from_list(None, typ=NullType()),
                                location,
                                pos=pos))
            return LLLnode.from_list(['with', '_L', left, ['seq'] + subs],
                                     typ=None)
        # If the right side is a variable
        # (bind it to '_R' so it is evaluated once, then copy element-wise)
        else:
            right_token = LLLnode.from_list('_R',
                                            typ=right.typ,
                                            location=right.location)
            subs = []
            for i in range(left.typ.count):
                subs.append(
                    make_setter(add_variable_offset(
                        left_token,
                        LLLnode.from_list(i, typ='int128'),
                        pos=pos,
                        array_bounds_check=False,
                    ),
                                add_variable_offset(
                                    right_token,
                                    LLLnode.from_list(i, typ='int128'),
                                    pos=pos,
                                    array_bounds_check=False,
                                ),
                                location,
                                pos=pos))
            return LLLnode.from_list(
                ['with', '_L', left, ['with', '_R', right, ['seq'] + subs]],
                typ=None)
    # Structs
    elif isinstance(left.typ, TupleLike):
        if left.value == "multi" and isinstance(left.typ, StructType):
            raise Exception("Target of set statement must be a single item")
        if not isinstance(right.typ, NullType):
            if not isinstance(right.typ, left.typ.__class__):
                raise TypeMismatchException(
                    f"Setter type mismatch: left side is {left.typ}, right side is {right.typ}",
                    pos,
                )
            if isinstance(left.typ, StructType):
                for k in right.args:
                    if k.value is None:
                        raise InvalidLiteralException(
                            'Setting struct value to None is not allowed, use a default value.',
                            pos,
                        )
                # Both key sets must match exactly (missing keys are
                # reported before extra ones).
                for k in left.typ.members:
                    if k not in right.typ.members:
                        raise TypeMismatchException(
                            f"Keys don't match for structs, missing {k}",
                            pos,
                        )
                for k in right.typ.members:
                    if k not in left.typ.members:
                        raise TypeMismatchException(
                            f"Keys don't match for structs, extra {k}",
                            pos,
                        )
                if left.typ.name != right.typ.name:
                    raise TypeMismatchException(
                        f"Expected {left.typ}, got {right.typ}", pos)
            else:
                if len(left.typ.members) != len(right.typ.members):
                    raise TypeMismatchException(
                        "Tuple lengths don't match, "
                        f"{len(left.typ.members)} vs {len(right.typ.members)}",
                        pos,
                    )

        left_token = LLLnode.from_list('_L',
                                       typ=left.typ,
                                       location=left.location)
        # Same storage pre-hashing as in the list case above.
        if left.location == "storage":
            left = LLLnode.from_list(['sha3_32', left],
                                     typ=left.typ,
                                     location="storage_prehashed")
            left_token.location = "storage_prehashed"
        keyz = left.typ.tuple_keys()

        # If the left side is a literal
        # (each target keeps its own location; otherwise all members use
        # the caller-supplied location)
        if left.value == 'multi':
            locations = [arg.location for arg in left.args]
        else:
            locations = [location for _ in keyz]

        # If the right side is a literal
        if right.value == "multi":
            if len(right.args) != len(keyz):
                raise TypeMismatchException("Mismatched number of elements",
                                            pos)
            # get the RHS arguments into a dict because
            # they are not guaranteed to be in the same order
            # the LHS keys.
            right_args = dict(zip(right.typ.tuple_keys(), right.args))
            subs = []
            for (key, loc) in zip(keyz, locations):
                subs.append(
                    make_setter(
                        add_variable_offset(left_token, key, pos=pos),
                        right_args[key],
                        loc,
                        pos=pos,
                    ))
            return LLLnode.from_list(['with', '_L', left, ['seq'] + subs],
                                     typ=None)
        # If the right side is a null
        # (the loop variable `typ` holds a member key, not a type)
        elif isinstance(right.typ, NullType):
            subs = []
            for typ, loc in zip(keyz, locations):
                subs.append(
                    make_setter(
                        add_variable_offset(left_token, typ, pos=pos),
                        LLLnode.from_list(None, typ=NullType()),
                        loc,
                        pos=pos,
                    ))
            return LLLnode.from_list(['with', '_L', left, ['seq'] + subs],
                                     typ=None)
        # If tuple assign.
        elif isinstance(left.typ, TupleType) and isinstance(
                right.typ, TupleType):
            subs = []
            # Byte offset of the next member within the memory region bound
            # to '_R'.
            static_offset_counter = 0
            zipped_components = zip(left.args, right.typ.members, locations)
            for var_arg in left.args:
                if var_arg.location == 'calldata':
                    raise ConstancyViolationException(
                        f"Cannot modify function argument: {var_arg.annotation}",
                        pos)
            for left_arg, right_arg, loc in zipped_components:
                # Byte-array members appear to be stored as a 32-byte offset
                # (relative to '_R') pointing at the data; other members are
                # read inline at the current offset.
                if isinstance(right_arg, ByteArrayLike):
                    RType = ByteArrayType if isinstance(
                        right_arg, ByteArrayType) else StringType
                    offset = LLLnode.from_list([
                        'add', '_R',
                        ['mload', ['add', '_R', static_offset_counter]]
                    ],
                                               typ=RType(right_arg.maxlen),
                                               location='memory',
                                               pos=pos)
                    static_offset_counter += 32
                else:
                    offset = LLLnode.from_list(
                        ['mload', ['add', '_R', static_offset_counter]],
                        typ=right_arg.typ,
                        pos=pos,
                    )
                    static_offset_counter += get_size_of_type(right_arg) * 32
                subs.append(make_setter(left_arg, offset, loc, pos=pos))
            return LLLnode.from_list(
                ['with', '_R', right, ['seq'] + subs],
                typ=None,
                annotation='Tuple assignment',
            )
        # If the right side is a variable
        # (the loop variable `typ` holds a member key, not a type)
        else:
            subs = []
            right_token = LLLnode.from_list('_R',
                                            typ=right.typ,
                                            location=right.location)
            for typ, loc in zip(keyz, locations):
                subs.append(
                    make_setter(add_variable_offset(left_token, typ, pos=pos),
                                add_variable_offset(right_token, typ, pos=pos),
                                loc,
                                pos=pos))
            return LLLnode.from_list(
                ['with', '_L', left, ['with', '_R', right, ['seq'] + subs]],
                typ=None,
            )
    else:
        raise Exception("Invalid type for setters")
Esempio n. 8
0
def to_int128(expr, args, kwargs, context):
    """Build LLL converting ``args[0]`` to an ``int128`` value.

    Compile-time literals are range-checked and returned directly. Runtime
    values are clamped into the int128 range, or masked / passed through
    for ``address``, ``bool`` and ``uint8`` inputs.

    :param expr: AST node of the conversion; used for source positions.
    :param args: positional arguments; only ``args[0]`` is read.
    :param kwargs: not used here.
    :param context: not used here.
    :returns: an ``LLLnode`` typed ``int128``.
    :raises InvalidLiteral: a literal outside the int128 range, an unknown
        numeric literal type, or an unsupported input type.
    :raises TypeMismatch: a ``String``/``Bytes`` input whose maxlen exceeds 32.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "num_literal":
        if isinstance(in_arg, int):
            value = in_arg
        elif isinstance(in_arg, Decimal):
            # Decimal literals are truncated toward zero.
            value = math.trunc(in_arg)
        else:
            # Pass ``expr`` (as every other raise here does) so the error
            # carries a source location.
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}",
                                 expr)
        if not SizeLimits.in_bounds("int128", value):
            raise InvalidLiteral(f"Number out of range: {value}", expr)
        return LLLnode.from_list(value,
                                 typ=BaseType("int128"),
                                 pos=getpos(expr))

    elif input_type in ("bytes32", "int256"):
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}",
                                     expr)
            else:
                return LLLnode.from_list(in_arg,
                                         typ=BaseType("int128"),
                                         pos=getpos(expr))
        else:
            # cast to output type so clamp_basetype works
            in_arg = LLLnode.from_list(in_arg, typ="int128")
            return LLLnode.from_list(
                clamp_basetype(in_arg),
                typ=BaseType("int128"),
                pos=getpos(expr),
            )

    elif input_type == "address":
        # Mask to the address width (ADDRSIZE - 1), then SIGNEXTEND from
        # byte 15, which keeps only the low 128 bits as a signed value.
        # NOTE(review): this drops the top 32 bits of a 160-bit address
        # (assuming SizeLimits.ADDRSIZE == 2**160) -- confirm it is intended.
        return LLLnode.from_list(
            ["signextend", 15, ["and", in_arg, (SizeLimits.ADDRSIZE - 1)]],
            typ=BaseType("int128"),
            pos=getpos(expr),
        )

    elif input_type in ("String", "Bytes"):
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to int128",
                expr,
            )
        return byte_array_to_num(in_arg, "int128")

    elif input_type == "uint256":
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}",
                                     expr)
            else:
                return LLLnode.from_list(in_arg,
                                         typ=BaseType("int128"),
                                         pos=getpos(expr))

        # !! do not use clamp_basetype. check that 0 <= input <= MAX_INT128.
        res = int_clamp(in_arg, 127, signed=False)
        return LLLnode.from_list(
            res,
            typ="int128",
            pos=getpos(expr),
        )

    elif input_type == "decimal":
        # Signed division by the decimal scale truncates toward zero; cast
        # to int128 so clamp_basetype works.
        res = LLLnode.from_list(["sdiv", in_arg, DECIMAL_DIVISOR],
                                typ="int128")
        return LLLnode.from_list(clamp_basetype(res),
                                 typ="int128",
                                 pos=getpos(expr))

    elif input_type in ("bool", "uint8"):
        # note: for int8, would need signextend
        return LLLnode.from_list(in_arg,
                                 typ=BaseType("int128"),
                                 pos=getpos(expr))

    else:
        raise InvalidLiteral(f"Invalid input for int128: {in_arg}", expr)
Esempio n. 9
0
def test_arithmetic_thorough(get_contract, assert_tx_failed,
                             assert_compile_failed, op, typ, lo, hi, bits):
    """Cross-check ``x <op> y`` for ``typ`` on edge-case and random operands.

    Every (x, y) pair is compiled four ways: both operands as arguments,
    right operand literal, left operand literal, and both literal. Each
    variant must agree with the Python reference operation.

    * Results outside the bounds of ``typ`` must revert at runtime, or fail
      to compile when fully constant.
    * Division/modulo by zero must revert at runtime, or raise
      ``ZeroDivisionException`` at compile time when the divisor is a
      literal.
    """
    # Both operands are runtime arguments.
    src_both_args = f"""
@external
def foo(x: {typ}, y: {typ}) -> {typ}:
    return x {op} y
    """
    # Right operand is a literal.
    tmpl_right_literal = """
@external
def foo(x: {typ}) -> {typ}:
    return x {op} {y}
    """
    # Left operand is a literal.
    tmpl_left_literal = """
@external
def foo(y: {typ}) -> {typ}:
    return {x} {op} y
    """
    # Both operands are literals.
    tmpl_both_literal = """
@external
def foo() -> {typ}:
    return {x} {op} {y}
    """

    reference = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,
        "/": evm_div,
        "%": evm_mod,
    }[op]

    contract_both_args = get_contract(src_both_args)

    # TODO refactor to use fixtures
    edge_values = [
        lo, lo + 1,
        lo // 2, lo // 2 - 1, lo // 2 + 1,
        -3, -2, -1, 0, 1, 2, 3,
        hi // 2 - 1, hi // 2, hi // 2 + 1,
        hi - 1, hi,
    ]

    # poor man's fuzzing: hypothesis does not combine easily with the
    # parametrized strategy, so append a few random samples per side.
    # Every (x, y) combination is then checked: len(xs) * len(ys) pairs.
    NUM_CASES = 5
    xs = edge_values + [random.randrange(lo, hi) for _ in range(NUM_CASES)]
    ys = edge_values + [random.randrange(lo, hi) for _ in range(NUM_CASES)]

    # edge cases that are tricky to reason about and MUST be tested
    assert lo in xs and -1 in ys

    for x, y in itertools.product(xs, ys):
        expected = reference(x, y)
        in_bounds = SizeLimits.in_bounds(typ, expected)
        # safediv and safemod disallow divisor == 0
        div_by_zero = y == 0 and op in ("/", "%")

        src_right = tmpl_right_literal.format(typ=typ, op=op, y=y)
        src_left = tmpl_left_literal.format(typ=typ, op=op, x=x)
        src_const = tmpl_both_literal.format(typ=typ, op=op, x=x, y=y)

        if div_by_zero:
            assert_tx_failed(lambda: contract_both_args.foo(x, y))
            assert_compile_failed(lambda: get_contract(src_right),
                                  ZeroDivisionException)
            assert_tx_failed(lambda: get_contract(src_left).foo(y))
            assert_compile_failed(lambda: get_contract(src_const),
                                  ZeroDivisionException)
        elif in_bounds:
            assert contract_both_args.foo(x, y) == expected
            assert get_contract(src_right).foo(x) == expected
            assert get_contract(src_left).foo(y) == expected
            assert get_contract(src_const).foo() == expected
        else:
            # Overflow: reverts at runtime, rejected when fully constant.
            assert_tx_failed(lambda: contract_both_args.foo(x, y))
            assert_tx_failed(lambda: get_contract(src_right).foo(x))
            assert_tx_failed(lambda: get_contract(src_left).foo(y))
            assert_compile_failed(lambda: get_contract(src_const),
                                  (InvalidType, OverflowException))
Esempio n. 10
0
    def parse_return(self):
        """Compile a ``return`` statement into LLL.

        Checks the returned expression against ``self.context.return_type``.
        It then emits code that places the value in memory and hands it to
        ``make_return_stmt`` (base types, byte arrays, lists) or to
        ``gen_tuple_return`` (structs, tuples).

        :returns: an ``LLLnode`` with ``typ=None``.
        :raises TypeMismatchException: a value is returned from a function
            with no return type (or vice versa), or its type does not match
            the declared return type.
        :raises InvalidLiteralException: an integer literal is out of range
            for the declared return type.
        :raises StructureException: a returned tuple's length or member
            types differ from the declaration.
        """
        if self.context.return_type is None:
            if self.stmt.value:
                raise TypeMismatchException("Not expecting to return a value",
                                            self.stmt)
            return LLLnode.from_list(make_return_stmt(self.stmt, self.context,
                                                      0, 0),
                                     typ=None,
                                     pos=getpos(self.stmt),
                                     valency=0)
        if not self.stmt.value:
            raise TypeMismatchException("Expecting to return a value",
                                        self.stmt)

        def zero_pad(bytez_placeholder, maxlen):
            """Return LLL that writes zero bytes after a byte array's data.

            Starts at the array's runtime length (the word stored at
            ``bytez_placeholder``). It writes for at most ``maxlen`` rounds
            and stops once the offset passes ``ceil32`` of the length.
            Returns ``['pass']`` when ``maxlen`` is 0.
            """
            zero_padder = LLLnode.from_list(['pass'])
            if maxlen > 0:
                # Iterator used to zero pad memory.
                zero_pad_i = self.context.new_placeholder(BaseType('uint256'))
                # NOTE(review): because the break test is ``gt``, the byte at
                # offset ceil32(len) is also zeroed -- one past the padded
                # region. Also, with a small maxlen the round limit may stop
                # before ceil32(len) is reached. Confirm both are intended.
                zero_padder = LLLnode.from_list(
                    ['with', '_ceil32_end', ['ceil32', ['mload', bytez_placeholder]],
                     ['repeat', zero_pad_i, ['mload', bytez_placeholder], maxlen,
                      ['seq',
                       # stay within allocated bounds
                       ['if', ['gt', ['mload', zero_pad_i], '_ceil32_end'],
                        'break'],
                       ['mstore8',
                        ['add', ['add', 32, bytez_placeholder],
                         ['mload', zero_pad_i]],
                        0]]]],
                    annotation="Zero pad")
            return zero_padder

        sub = Expr(self.stmt.value, self.context).lll_node
        self.context.increment_return_counter()
        # Returning a value (most common case)
        if isinstance(sub.typ, BaseType):
            sub = unwrap_location(sub)

            if not isinstance(self.context.return_type, BaseType):
                # The declared return type is not a base type at all (the
                # previous "units mismatch" message did not describe this).
                raise TypeMismatchException(
                    "Trying to return base type %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            elif self.context.return_type != sub.typ and not sub.typ.is_literal:
                raise TypeMismatchException(
                    "Trying to return base type %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            elif sub.typ.is_literal and (
                    self.context.return_type.typ == sub.typ or 'int'
                    in self.context.return_type.typ and 'int' in sub.typ.typ):
                # NOTE(review): ``return_type.typ == sub.typ`` compares a type
                # name with a type object and is likely always False, so in
                # practice only int-to-int literals reach this range check.
                # Widening it would send other literal kinds through
                # SizeLimits.in_bounds -- check that it supports them first.
                if not SizeLimits.in_bounds(self.context.return_type.typ,
                                            sub.value):
                    raise InvalidLiteralException(
                        "Number out of range: " + str(sub.value), self.stmt)
                return LLLnode.from_list(
                    ['seq', ['mstore', 0, sub],
                     make_return_stmt(self.stmt, self.context, 0, 32)],
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0)
            elif is_base_type(sub.typ, self.context.return_type.typ) or \
                    (is_base_type(sub.typ, 'int128') and is_base_type(self.context.return_type, 'int256')):
                return LLLnode.from_list(
                    ['seq', ['mstore', 0, sub],
                     make_return_stmt(self.stmt, self.context, 0, 32)],
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0)
            else:
                raise TypeMismatchException(
                    "Unsupported type conversion: %r to %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
        # Returning a byte array
        elif isinstance(sub.typ, ByteArrayLike):
            if not sub.typ.eq_base(self.context.return_type):
                raise TypeMismatchException(
                    "Trying to return base type %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            if sub.typ.maxlen > self.context.return_type.maxlen:
                raise TypeMismatchException(
                    "Cannot cast from greater max-length %d to shorter max-length %d"
                    % (sub.typ.maxlen, self.context.return_type.maxlen),
                    self.stmt.value)

            # loop memory has to be allocated first.
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType('uint256'))
            # len & bytez placeholder have to be declared after each other at
            # all times: the returned region starts at len_placeholder and
            # runs on into the copied bytes.
            len_placeholder = self.context.new_placeholder(
                typ=BaseType('uint256'))
            bytez_placeholder = self.context.new_placeholder(typ=sub.typ)

            if sub.location in ('storage', 'memory'):
                return LLLnode.from_list(
                    [
                        'seq',
                        make_byte_array_copier(
                            LLLnode(bytez_placeholder,
                                    location='memory',
                                    typ=sub.typ),
                            sub,
                            pos=getpos(self.stmt)),
                        zero_pad(bytez_placeholder, sub.typ.maxlen),
                        ['mstore', len_placeholder, 32],
                        make_return_stmt(
                            self.stmt,
                            self.context,
                            len_placeholder,
                            ['ceil32', ['add', ['mload', bytez_placeholder], 64]],
                            loop_memory_position=loop_memory_position),
                    ],
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0)
            else:
                raise Exception("Invalid location: %s" % sub.location)

        elif isinstance(sub.typ, ListType):
            # Check the declared type before touching ``.subtype``. Without
            # this guard, returning a list from a function declared to return
            # a non-list type crashed with AttributeError instead of a type
            # error.
            if not isinstance(self.context.return_type, ListType):
                raise TypeMismatchException(
                    "Trying to return list type %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            # Compare only the leading base-type name of each subtype (the
            # text before any '(' or '[').
            sub_base_type = re.split(r'\(|\[', str(sub.typ.subtype))[0]
            ret_base_type = re.split(r'\(|\[',
                                     str(self.context.return_type.subtype))[0]
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType('uint256'))
            if sub_base_type != ret_base_type:
                raise TypeMismatchException(
                    "List return type %r does not match specified return type, expecting %r"
                    % (sub_base_type, ret_base_type), self.stmt)
            elif sub.location == "memory" and sub.value != "multi":
                # Already a contiguous memory list: return it in place.
                return LLLnode.from_list(
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        sub,
                        get_size_of_type(self.context.return_type) * 32,
                        loop_memory_position=loop_memory_position),
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0)
            else:
                # Literal or non-memory list: copy into a fresh memory buffer
                # of the declared return type first.
                new_sub = LLLnode.from_list(
                    self.context.new_placeholder(self.context.return_type),
                    typ=self.context.return_type,
                    location='memory')
                setter = make_setter(new_sub,
                                     sub,
                                     'memory',
                                     pos=getpos(self.stmt))
                return LLLnode.from_list(
                    [
                        'seq', setter,
                        make_return_stmt(
                            self.stmt,
                            self.context,
                            new_sub,
                            get_size_of_type(self.context.return_type) * 32,
                            loop_memory_position=loop_memory_position),
                    ],
                    typ=None,
                    pos=getpos(self.stmt))

        # Returning a struct
        elif isinstance(sub.typ, StructType):
            retty = self.context.return_type
            if not isinstance(retty, StructType) or retty.name != sub.typ.name:
                raise TypeMismatchException(
                    "Trying to return %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            return gen_tuple_return(self.stmt, self.context, sub)

        # Returning a tuple.
        elif isinstance(sub.typ, TupleType):
            if not isinstance(self.context.return_type, TupleType):
                raise TypeMismatchException(
                    "Trying to return tuple type %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)

            if len(self.context.return_type.members) != len(sub.typ.members):
                raise StructureException("Tuple lengths don't match!",
                                         self.stmt)

            # check return type matches, sub type.
            for i, ret_x in enumerate(self.context.return_type.members):
                s_member = sub.typ.members[i]
                # Members may be types or nodes carrying a ``.typ``.
                sub_type = s_member if isinstance(s_member,
                                                  NodeType) else s_member.typ
                if type(sub_type) is not type(ret_x):
                    raise StructureException(
                        "Tuple return type does not match annotated return. {} != {}"
                        .format(type(sub_type), type(ret_x)), self.stmt)
            return gen_tuple_return(self.stmt, self.context, sub)

        else:
            raise TypeMismatchException("Can't return type %r" % sub.typ,
                                        self.stmt)
Esempio n. 11
0
    def parse_return(self):
        if self.context.return_type is None:
            if self.stmt.value:
                raise TypeMismatchException("Not expecting to return a value", self.stmt)
            return LLLnode.from_list(self.make_return_stmt(0, 0), typ=None, pos=getpos(self.stmt))
        if not self.stmt.value:
            raise TypeMismatchException("Expecting to return a value", self.stmt)

        def zero_pad(bytez_placeholder, maxlen):
            zero_padder = LLLnode.from_list(['pass'])
            if maxlen > 0:
                zero_pad_i = self.context.new_placeholder(BaseType('uint256'))  # Iterator used to zero pad memory.
                zero_padder = LLLnode.from_list(
                    ['repeat', zero_pad_i, ['mload', bytez_placeholder], maxlen,
                        ['seq',
                            ['if', ['gt', ['mload', zero_pad_i], maxlen], 'break'],  # stay within allocated bounds
                            ['mstore8', ['add', ['add', 32, bytez_placeholder], ['mload', zero_pad_i]], 0]]],
                    annotation="Zero pad"
                )
            return zero_padder

        sub = Expr(self.stmt.value, self.context).lll_node
        self.context.increment_return_counter()
        # Returning a value (most common case)
        if isinstance(sub.typ, BaseType):
            if not isinstance(self.context.return_type, BaseType):
                raise TypeMismatchException("Trying to return base type %r, output expecting %r" % (sub.typ, self.context.return_type), self.stmt.value)
            sub = unwrap_location(sub)
            if not are_units_compatible(sub.typ, self.context.return_type):
                raise TypeMismatchException("Return type units mismatch %r %r" % (sub.typ, self.context.return_type), self.stmt.value)
            elif sub.typ.is_literal and (self.context.return_type.typ == sub.typ or
                    'int' in self.context.return_type.typ and
                    'int' in sub.typ.typ):
                if not SizeLimits.in_bounds(self.context.return_type.typ, sub.value):
                    raise InvalidLiteralException("Number out of range: " + str(sub.value), self.stmt)
                else:
                    return LLLnode.from_list(['seq', ['mstore', 0, sub], self.make_return_stmt(0, 32)], typ=None, pos=getpos(self.stmt))
            elif is_base_type(sub.typ, self.context.return_type.typ) or \
                    (is_base_type(sub.typ, 'int128') and is_base_type(self.context.return_type, 'int256')):
                return LLLnode.from_list(['seq', ['mstore', 0, sub], self.make_return_stmt(0, 32)], typ=None, pos=getpos(self.stmt))
            else:
                raise TypeMismatchException("Unsupported type conversion: %r to %r" % (sub.typ, self.context.return_type), self.stmt.value)
        # Returning a byte array
        elif isinstance(sub.typ, ByteArrayType):
            if not isinstance(self.context.return_type, ByteArrayType):
                raise TypeMismatchException("Trying to return base type %r, output expecting %r" % (sub.typ, self.context.return_type), self.stmt.value)
            if sub.typ.maxlen > self.context.return_type.maxlen:
                raise TypeMismatchException("Cannot cast from greater max-length %d to shorter max-length %d" %
                                            (sub.typ.maxlen, self.context.return_type.maxlen), self.stmt.value)

            loop_memory_position = self.context.new_placeholder(typ=BaseType('uint256'))  # loop memory has to be allocated first.
            len_placeholder = self.context.new_placeholder(typ=BaseType('uint256'))  # len & bytez placeholder have to be declared after each other at all times.
            bytez_placeholder = self.context.new_placeholder(typ=sub.typ)

            if sub.location in ('storage', 'memory'):
                return LLLnode.from_list([
                    'seq',
                    make_byte_array_copier(
                        LLLnode(bytez_placeholder, location='memory', typ=sub.typ),
                        sub,
                        pos=getpos(self.stmt)
                    ),
                    zero_pad(bytez_placeholder, sub.typ.maxlen),
                    ['mstore', len_placeholder, 32],
                    self.make_return_stmt(len_placeholder, ['ceil32', ['add', ['mload', bytez_placeholder], 64]], loop_memory_position=loop_memory_position)],
                    typ=None, pos=getpos(self.stmt)
                )
            else:
                raise Exception("Invalid location: %s" % sub.location)

        elif isinstance(sub.typ, ListType):
            sub_base_type = re.split(r'\(|\[', str(sub.typ.subtype))[0]
            ret_base_type = re.split(r'\(|\[', str(self.context.return_type.subtype))[0]
            loop_memory_position = self.context.new_placeholder(typ=BaseType('uint256'))
            if sub_base_type != ret_base_type:
                raise TypeMismatchException(
                    "List return type %r does not match specified return type, expecting %r" % (
                        sub_base_type, ret_base_type
                    ),
                    self.stmt
                )
            elif sub.location == "memory" and sub.value != "multi":
                return LLLnode.from_list(self.make_return_stmt(sub, get_size_of_type(self.context.return_type) * 32, loop_memory_position=loop_memory_position),
                                            typ=None, pos=getpos(self.stmt))
            else:
                new_sub = LLLnode.from_list(self.context.new_placeholder(self.context.return_type), typ=self.context.return_type, location='memory')
                setter = make_setter(new_sub, sub, 'memory', pos=getpos(self.stmt))
                return LLLnode.from_list(['seq', setter, self.make_return_stmt(new_sub, get_size_of_type(self.context.return_type) * 32, loop_memory_position=loop_memory_position)],
                                            typ=None, pos=getpos(self.stmt))

        # Returning a tuple.
        elif isinstance(sub.typ, TupleType):
            if len(self.context.return_type.members) != len(sub.typ.members):
                raise StructureException("Tuple lengths don't match!", self.stmt)
            # check return type matches, sub type.
            for i, ret_x in enumerate(self.context.return_type.members):
                s_member = sub.typ.members[i]
                sub_type = s_member if isinstance(s_member, NodeType) else s_member.typ
                if type(sub_type) is not type(ret_x):
                    raise StructureException(
                        "Tuple return type does not match annotated return. {} != {}".format(
                            type(sub_type), type(ret_x)
                        ),
                        self.stmt
                    )

            # Is from a call expression.
            if len(sub.args[0].args) > 0 and sub.args[0].args[0].value == 'call':  # self-call to public.
                mem_pos = sub.args[0].args[-1]
                mem_size = get_size_of_type(sub.typ) * 32
                return LLLnode.from_list(['return', mem_pos, mem_size], typ=sub.typ)

            elif (sub.annotation and 'Internal Call' in sub.annotation):
                mem_pos = sub.args[-1].value if sub.value == 'seq_unchecked' else sub.args[0].args[-1]
                mem_size = get_size_of_type(sub.typ) * 32
                # Add zero padder if bytes are present in output.
                zero_padder = ['pass']
                byte_arrays = [(i, x) for i, x in enumerate(sub.typ.members) if isinstance(x, ByteArrayType)]
                if byte_arrays:
                    i, x = byte_arrays[-1]
                    zero_padder = zero_pad(bytez_placeholder=['add', mem_pos, ['mload', mem_pos + i * 32]], maxlen=x.maxlen)
                return LLLnode.from_list(
                    ['seq'] +
                    [sub] +
                    [zero_padder] +
                    [self.make_return_stmt(mem_pos, mem_size)
                ], typ=sub.typ, pos=getpos(self.stmt))

            subs = []
            # Pre-allocate loop_memory_position if required for private function returning.
            loop_memory_position = self.context.new_placeholder(typ=BaseType('uint256')) if self.context.is_private else None
            # Allocate dynamic off set counter, to keep track of the total packed dynamic data size.
            dynamic_offset_counter_placeholder = self.context.new_placeholder(typ=BaseType('uint256'))
            dynamic_offset_counter = LLLnode(
                dynamic_offset_counter_placeholder, typ=None, annotation="dynamic_offset_counter"  # dynamic offset position counter.
            )
            new_sub = LLLnode.from_list(
                self.context.new_placeholder(typ=BaseType('uint256')), typ=self.context.return_type, location='memory', annotation='new_sub'
            )
            keyz = list(range(len(sub.typ.members)))
            dynamic_offset_start = 32 * len(sub.args)  # The static list of args end.
            left_token = LLLnode.from_list('_loc', typ=new_sub.typ, location="memory")

            def get_dynamic_offset_value():
                # Get value of dynamic offset counter.
                return ['mload', dynamic_offset_counter]

            def increment_dynamic_offset(dynamic_spot):
                # Increment dyanmic offset counter in memory.
                return [
                    'mstore', dynamic_offset_counter,
                    ['add',
                        ['add', ['ceil32', ['mload', dynamic_spot]], 32],
                        ['mload', dynamic_offset_counter]]
                ]

            for i, typ in enumerate(keyz):
                arg = sub.args[i]
                variable_offset = LLLnode.from_list(['add', 32 * i, left_token], typ=arg.typ, annotation='variable_offset')
                if isinstance(arg.typ, ByteArrayType):
                    # Store offset pointer value.
                    subs.append(['mstore', variable_offset, get_dynamic_offset_value()])

                    # Store dynamic data, from offset pointer onwards.
                    dynamic_spot = LLLnode.from_list(['add', left_token, get_dynamic_offset_value()], location="memory", typ=arg.typ, annotation='dynamic_spot')
                    subs.append(make_setter(dynamic_spot, arg, location="memory", pos=getpos(self.stmt)))
                    subs.append(increment_dynamic_offset(dynamic_spot))

                elif isinstance(arg.typ, BaseType):
                    subs.append(make_setter(variable_offset, arg, "memory", pos=getpos(self.stmt)))
                else:
                    raise Exception("Can't return type %s as part of tuple", type(arg.typ))

            setter = LLLnode.from_list(
                ['seq',
                    ['mstore', dynamic_offset_counter, dynamic_offset_start],
                    ['with', '_loc', new_sub, ['seq'] + subs]],
                typ=None
            )

            return LLLnode.from_list(
                ['seq',
                    setter,
                    self.make_return_stmt(new_sub, get_dynamic_offset_value(), loop_memory_position)],
                typ=None, pos=getpos(self.stmt)
            )
        else:
            raise TypeMismatchException("Can only return base type!", self.stmt)
Esempio n. 12
0
    def parse_Compare(self):
        """Lower a comparison expression to an LLL node of type ``bool``.

        Supports ``<``, ``<=``, ``>``, ``>=``, ``==`` and ``!=`` on numeric and
        byte-array operands, and ``in`` against a list (delegated to
        ``build_in_comparator``).

        Returns ``None`` for combinations this code path does not lower:
        a right operand whose value is ``None``, ``in`` with a mismatched
        element type or a non-list right side, an unrecognized operator,
        ordering comparisons of byte arrays that must be compared by hash,
        ordering comparisons of non-numeric values, mismatched numeric types,
        and int128/uint256 literal comparisons whose literal is out of range.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.right, self.context)

        if right.value is None:
            return

        if isinstance(left.typ, ByteArrayLike) and isinstance(right.typ, ByteArrayLike):
            # TODO: Can this if branch be removed ^
            pass

        elif isinstance(self.expr.op, vy_ast.In) and isinstance(right.typ, ListType):
            if left.typ != right.typ.subtype:
                return
            return self.build_in_comparator()

        # Map the AST operator to an EVM comparison opcode. Ordering opcodes
        # start out signed; they are switched to unsigned below where needed.
        if isinstance(self.expr.op, vy_ast.Gt):
            op = "sgt"
        elif isinstance(self.expr.op, vy_ast.GtE):
            op = "sge"
        elif isinstance(self.expr.op, vy_ast.LtE):
            op = "sle"
        elif isinstance(self.expr.op, vy_ast.Lt):
            op = "slt"
        elif isinstance(self.expr.op, vy_ast.Eq):
            op = "eq"
        elif isinstance(self.expr.op, vy_ast.NotEq):
            op = "ne"
        else:
            return

        # Compare (limited to 32) byte arrays.
        if isinstance(left.typ, ByteArrayLike) and isinstance(right.typ, ByteArrayLike):
            # Re-parse the operands as plain LLL nodes (not value expressions).
            left = Expr(self.expr.left, self.context).lll_node
            right = Expr(self.expr.right, self.context).lll_node

            length_mismatch = left.typ.maxlen != right.typ.maxlen
            left_over_32 = left.typ.maxlen > 32
            right_over_32 = right.typ.maxlen > 32
            if length_mismatch or left_over_32 or right_over_32:
                # Compare by keccak256 hash; only equality tests are lowered.
                left_keccak = keccak256_helper(self.expr, [left], None, self.context)
                right_keccak = keccak256_helper(self.expr, [right], None, self.context)

                if op == "eq" or op == "ne":
                    return LLLnode.from_list(
                        [op, left_keccak, right_keccak], typ="bool", pos=getpos(self.expr),
                    )

                else:
                    return

            else:

                def load_bytearray(side):
                    # Load one 32-byte word of the array: memory at offset 32
                    # from the pointer, storage at slot sha3_32(side) + 1
                    # (presumably the first data word after the length word).
                    # NOTE(review): other locations fall through and yield None.
                    if side.location == "memory":
                        return ["mload", ["add", 32, side]]
                    elif side.location == "storage":
                        return ["sload", ["add", 1, ["sha3_32", side]]]

                return LLLnode.from_list(
                    [op, load_bytearray(left), load_bytearray(right)],
                    typ="bool",
                    pos=getpos(self.expr),
                )

        # Compare other types.
        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            if op not in ("eq", "ne"):
                return
        left_type, right_type = left.typ.typ, right.typ.typ

        # Special Case: comparison of a literal integer. If in valid range allow it to be compared.
        if {left_type, right_type} == {"int128", "uint256"} and {
            left.typ.is_literal,
            right.typ.is_literal,
        } == {True, False}:
            # Exactly one side is a literal: it must fit in the other side's type.
            if left.typ.is_literal:
                literal_value, variable_type = left.value, right_type
            else:
                literal_value, variable_type = right.value, left_type
            comparison_allowed = SizeLimits.in_bounds(variable_type, literal_value)

            # Once the literal fits, both operands are representable in the
            # non-literal operand's type, so compare with that type's
            # signedness. Switching to unsigned opcodes unconditionally would
            # misorder a negative int128 operand (e.g. -1 seen as 2**256 - 1).
            if variable_type == "uint256":
                op = self._signed_to_unsigned_comparision_op(op)

            if comparison_allowed:
                return LLLnode.from_list([op, left, right], typ="bool", pos=getpos(self.expr))

        elif left_type == right_type == "uint256":
            op = self._signed_to_unsigned_comparision_op(op)
        elif (
            left_type in ("decimal", "int128") or right_type in ("decimal", "int128")
        ) and left_type != right_type:
            return

        if left_type == right_type:
            return LLLnode.from_list([op, left, right], typ="bool", pos=getpos(self.expr))
Esempio n. 13
0
    def parse_BinOp(self):
        """Lower a numeric binary operation to an IR node.

        Handles ``+``, ``-``, ``*``, ``/``, ``%`` and ``**`` for the numeric
        base types tested below (``uint8``, ``int128``, ``int256``,
        ``uint256``, ``decimal``). The operation is first built over the
        names ``l`` and ``r``. The operands are then bound to those names
        with ``with``, and the result is passed through ``clamp_basetype``.
        ``**`` returns early with its own range clamp instead.

        Returns ``None`` when either operand is non-numeric, when the
        operator/type combination has no lowering here, when dividing or
        taking the modulus by a literal zero, or for ``a ** b`` where neither
        side is an integer literal.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.right, self.context)

        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            return

        types = {left.typ.typ, right.typ.typ}
        literals = {left.typ.is_literal, right.typ.is_literal}

        # If one value of the operation is a literal, we recast it to match the non-literal type.
        # We know this is OK because types were already verified in the actual typechecking pass.
        # This is a temporary solution to not break codegen while we work toward removing types
        # altogether at this stage of compilation. @iamdefinitelyahuman
        if literals == {True, False
                        } and len(types) > 1 and "decimal" not in types:
            if left.typ.is_literal and SizeLimits.in_bounds(
                    right.typ.typ, left.value):
                left = IRnode.from_list(left.value,
                                        typ=BaseType(right.typ.typ,
                                                     is_literal=True))
            elif right.typ.is_literal and SizeLimits.in_bounds(
                    left.typ.typ, right.value):
                right = IRnode.from_list(right.value,
                                         typ=BaseType(left.typ.typ,
                                                      is_literal=True))

        ltyp, rtyp = left.typ.typ, right.typ.typ

        # Sanity check - ensure that we aren't dealing with different types
        # This should be unreachable due to the type check pass
        assert ltyp == rtyp, "unreachable"

        # `arith` is the IR for the operation itself, written over the names
        # "l" and "r" that are bound to the operands at the end of this method.
        # It stays None for operator/type combinations not handled below.
        arith = None
        if isinstance(self.expr.op, (vy_ast.Add, vy_ast.Sub)):
            new_typ = BaseType(ltyp)

            if ltyp == "uint256":
                if isinstance(self.expr.op, vy_ast.Add):
                    # safeadd: unsigned wraparound would make l + r < l
                    arith = [
                        "seq", ["assert", ["ge", ["add", "l", "r"], "l"]],
                        ["add", "l", "r"]
                    ]

                elif isinstance(self.expr.op, vy_ast.Sub):
                    # safesub: l - r underflows exactly when l < r
                    arith = [
                        "seq", ["assert", ["ge", "l", "r"]], ["sub", "l", "r"]
                    ]

            elif ltyp == "int256":
                # Overflow check keyed on the sign of r: when r >= 0 the result
                # must satisfy `ans comp1 l` (add: ans >= l, sub: ans <= l);
                # when r < 0 it must satisfy `ans comp2 l` (add: ans < l,
                # sub: ans > l).
                if isinstance(self.expr.op, vy_ast.Add):
                    op, comp1, comp2 = "add", "sge", "slt"
                else:
                    op, comp1, comp2 = "sub", "sle", "sgt"

                if right.typ.is_literal:
                    # The sign of a literal r is known now; emit only one check.
                    if right.value >= 0:
                        arith = [
                            "seq", ["assert", [comp1, [op, "l", "r"], "l"]],
                            [op, "l", "r"]
                        ]
                    else:
                        arith = [
                            "seq", ["assert", [comp2, [op, "l", "r"], "l"]],
                            [op, "l", "r"]
                        ]
                else:
                    arith = [
                        "with",
                        "ans",
                        [op, "l", "r"],
                        [
                            "seq",
                            [
                                "assert",
                                [
                                    "or",
                                    [
                                        "and", ["sge", "r", 0],
                                        [comp1, "ans", "l"]
                                    ],
                                    [
                                        "and", ["slt", "r", 0],
                                        [comp2, "ans", "l"]
                                    ],
                                ],
                            ],
                            "ans",
                        ],
                    ]

            elif ltyp in ("decimal", "int128", "uint8"):
                # No explicit check here: the result is range-checked by
                # clamp_basetype at the end of this method.
                op = "add" if isinstance(self.expr.op, vy_ast.Add) else "sub"
                arith = [op, "l", "r"]

        elif isinstance(self.expr.op, vy_ast.Mult):
            new_typ = BaseType(ltyp)
            if ltyp == "uint256":
                # safemul: (l * r) / l must give back r unless l == 0
                arith = [
                    "with",
                    "ans",
                    ["mul", "l", "r"],
                    [
                        "seq",
                        [
                            "assert",
                            [
                                "or", ["eq", ["div", "ans", "l"], "r"],
                                ["iszero", "l"]
                            ]
                        ],
                        "ans",
                    ],
                ]

            elif ltyp == "int256":
                # upper_bound is -2**255 (the int256 minimum): written as
                # 1 << 255 via SHL when version_check allows Constantinople
                # opcodes (SHL arrived with EIP-145), otherwise as a literal.
                if version_check(begin="constantinople"):
                    upper_bound = ["shl", 255, 1]
                else:
                    upper_bound = -(2**255)
                # -1 * -2**255 overflows, but EVM sdiv(-2**255, -1) yields
                # -2**255 again, so the sdiv-based check below would pass it.
                # Guard that pair explicitly (["not", 0] is all ones, i.e. -1).
                # The reverse order (l == -2**255, r == -1) is caught by the
                # sdiv check, since sdiv(-2**255, -2**255) == 1 != -1.
                if not left.typ.is_literal and not right.typ.is_literal:
                    bounds_check = [
                        "assert",
                        [
                            "or", ["ne", "l", ["not", 0]],
                            ["ne", "r", upper_bound]
                        ],
                    ]
                elif left.typ.is_literal and left.value == -1:
                    bounds_check = ["assert", ["ne", "r", upper_bound]]
                elif right.typ.is_literal and right.value == -(2**255):
                    bounds_check = ["assert", ["ne", "l", ["not", 0]]]
                else:
                    # A literal operand rules out the dangerous pair.
                    bounds_check = "pass"
                arith = [
                    "with",
                    "ans",
                    ["mul", "l", "r"],
                    [
                        "seq",
                        bounds_check,
                        [
                            "assert",
                            [
                                "or", ["eq", ["sdiv", "ans", "l"], "r"],
                                ["iszero", "l"]
                            ]
                        ],
                        "ans",
                    ],
                ]

            elif ltyp in ("int128", "uint8"):
                # The product of two in-range int128/uint8 values fits in 256
                # bits; the range check is left to clamp_basetype below.
                arith = ["mul", "l", "r"]

            elif ltyp == "decimal":
                # Multiply, check the raw product like safemul, then divide by
                # DECIMAL_DIVISOR (presumably the fixed-point scale factor).
                arith = [
                    "with",
                    "ans",
                    ["mul", "l", "r"],
                    [
                        "seq",
                        [
                            "assert",
                            [
                                "or", ["eq", ["sdiv", "ans", "l"], "r"],
                                ["iszero", "l"]
                            ]
                        ],
                        ["sdiv", "ans", DECIMAL_DIVISOR],
                    ],
                ]

        elif isinstance(self.expr.op, vy_ast.Div):
            # A literal zero divisor is rejected outright.
            if right.typ.is_literal and right.value == 0:
                return

            new_typ = BaseType(ltyp)

            if right.typ.is_literal:
                divisor = "r"
            else:
                # only apply the non-zero clamp when r is not a constant
                divisor = ["clamp_nonzero", "r"]

            if ltyp in ("uint8", "uint256"):
                arith = ["div", "l", divisor]

            elif ltyp == "int256":
                if version_check(begin="constantinople"):
                    upper_bound = ["shl", 255, 1]
                else:
                    upper_bound = -(2**255)
                # -2**255 / -1 overflows (EVM sdiv returns -2**255), so that
                # pair is guarded explicitly; a literal operand narrows or
                # removes the guard.
                if not left.typ.is_literal and not right.typ.is_literal:
                    bounds_check = [
                        "assert",
                        [
                            "or", ["ne", "r", ["not", 0]],
                            ["ne", "l", upper_bound]
                        ],
                    ]
                elif left.typ.is_literal and left.value == -(2**255):
                    bounds_check = ["assert", ["ne", "r", ["not", 0]]]
                elif right.typ.is_literal and right.value == -1:
                    bounds_check = ["assert", ["ne", "l", upper_bound]]
                else:
                    bounds_check = "pass"
                arith = ["seq", bounds_check, ["sdiv", "l", divisor]]

            elif ltyp == "int128":
                # NOTE(review): -2**127 / -1 yields 2**127; presumably rejected
                # by clamp_basetype below.
                arith = ["sdiv", "l", divisor]

            elif ltyp == "decimal":
                # Scale the numerator first so the quotient keeps the
                # DECIMAL_DIVISOR scaling.
                arith = [
                    "sdiv",
                    ["mul", "l", DECIMAL_DIVISOR],
                    divisor,
                ]

        elif isinstance(self.expr.op, vy_ast.Mod):
            # A literal zero modulus is rejected outright.
            if right.typ.is_literal and right.value == 0:
                return

            new_typ = BaseType(ltyp)

            if right.typ.is_literal:
                divisor = "r"
            else:
                # only apply the non-zero clamp when r is not a constant
                divisor = ["clamp_nonzero", "r"]

            # Unsigned types use MOD; every other type uses the signed SMOD.
            if ltyp in ("uint8", "uint256"):
                arith = ["mod", "l", divisor]
            else:
                arith = ["smod", "l", divisor]

        elif isinstance(self.expr.op, vy_ast.Pow):
            new_typ = BaseType(ltyp)

            # TODO optimizer rule for special cases
            # 1 ** b == 1; note that `right` is not evaluated in this case.
            if self.expr.left.get("value") == 1:
                return IRnode.from_list([1], typ=new_typ)
            # 0 ** b is 1 only when b == 0, i.e. iszero(b).
            if self.expr.left.get("value") == 0:
                return IRnode.from_list(["iszero", right], typ=new_typ)

            # Width and signedness of the result type; any type not listed
            # falls back to unsigned 256-bit.
            if ltyp == "int128":
                is_signed = True
                num_bits = 128
            elif ltyp == "int256":
                is_signed = True
                num_bits = 256
            elif ltyp == "uint8":
                is_signed = False
                num_bits = 8
            else:
                is_signed = False
                num_bits = 256

            if isinstance(self.expr.left, vy_ast.Int):
                # Literal base: bound the exponent. calculate_largest_power is
                # assumed to return the largest exponent whose result fits.
                value = self.expr.left.value
                upper_bound = calculate_largest_power(value, num_bits,
                                                      is_signed) + 1
                # for signed integers, this also prevents negative values
                clamp = ["lt", right, upper_bound]
                return IRnode.from_list(
                    ["seq", ["assert", clamp], ["exp", left, right]],
                    typ=new_typ,
                )
            elif isinstance(self.expr.right, vy_ast.Int):
                # Literal exponent: bound the base. calculate_largest_base is
                # assumed to return the largest base whose power fits. Signed
                # bases must lie strictly between -upper_bound and upper_bound.
                value = self.expr.right.value
                upper_bound = calculate_largest_base(value, num_bits,
                                                     is_signed) + 1
                if is_signed:
                    clamp = [
                        "and", ["slt", left, upper_bound],
                        ["sgt", left, -upper_bound]
                    ]
                else:
                    clamp = ["lt", left, upper_bound]
                return IRnode.from_list(
                    ["seq", ["assert", clamp], ["exp", left, right]],
                    typ=new_typ)
            else:
                # `a ** b` where neither `a` or `b` are known
                # TODO this is currently unreachable, once we implement a way to do it safely
                # remove the check in `vyper/context/types/value/numeric.py`
                return

        # No lowering for this operator/type combination.
        if arith is None:
            return

        arith = IRnode.from_list(arith, typ=new_typ)

        # Bind each operand once to "l" / "r", then evaluate the operation
        # and range-check its result for the target type.
        p = [
            "with",
            "l",
            left,
            [
                "with",
                "r",
                right,
                # note clamp_basetype is a noop on [u]int256
                # note: clamp_basetype throws on unclampable input
                clamp_basetype(arith),
            ],
        ]
        return IRnode.from_list(p, typ=new_typ)
Esempio n. 14
0
def test_arithmetic_thorough(
    get_contract, assert_tx_failed, assert_compile_failed, op, typ, lo, hi, bits
):
    """Check ``x {op} y`` for ``typ`` against a Python reference implementation.

    Each (x, y) pair, drawn from boundary values plus a few random samples
    in ``[lo, hi)``, is compiled in four forms:
    - both operands as arguments;
    - the right operand as a literal;
    - the left operand as a literal;
    - both operands as literals.

    Each form must do one of three things:
    - return the reference result when it is in range for ``typ``;
    - revert or fail to compile with ``ZeroDivisionException`` on a zero
      divisor for ``/`` or ``%``;
    - otherwise revert, or fail to compile when everything is literal.
    """
    # both variables
    src_both_args = f"""
@external
def foo(x: {typ}, y: {typ}) -> {typ}:
    return x {op} y
    """
    # right is literal
    src_right_literal_tmpl = """
@external
def foo(x: {typ}) -> {typ}:
    return x {op} {y}
    """
    # left is literal
    src_left_literal_tmpl = """
@external
def foo(y: {typ}) -> {typ}:
    return {x} {op} y
    """
    # both literals
    src_all_literal_tmpl = """
@external
def foo() -> {typ}:
    return {x} {op} {y}
    """

    contract = get_contract(src_both_args)
    reference = ARITHMETIC_OPS[op]

    edge_values = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 2, hi - 1, hi]
    NUM_CASES = 5
    # Cheap fuzzing: hypothesis does not combine easily with the
    # parametrized strategy, so append a few uniformly random operands.
    xs = edge_values + [random.randrange(lo, hi) for _ in range(NUM_CASES)]
    ys = edge_values + [random.randrange(lo, hi) for _ in range(NUM_CASES)]

    # The boundary set must contain the values that mirror the signed tests.
    assert 2 ** (bits - 1) in xs
    assert (2 ** bits) - 1 in ys

    for x, y in itertools.product(xs, ys):
        expected = reference(x, y)
        fits = SizeLimits.in_bounds(typ, expected)
        # safediv and safemod disallow divisor == 0
        zero_divisor = op in ("/", "%") and y == 0

        right_literal_src = src_right_literal_tmpl.format(typ=typ, op=op, y=y)
        left_literal_src = src_left_literal_tmpl.format(typ=typ, op=op, x=x)
        all_literal_src = src_all_literal_tmpl.format(typ=typ, op=op, x=x, y=y)

        if zero_divisor:
            assert_tx_failed(lambda: contract.foo(x, y))
            assert_compile_failed(lambda: get_contract(right_literal_src), ZeroDivisionException)
            assert_tx_failed(lambda: get_contract(left_literal_src).foo(y))
            assert_compile_failed(lambda: get_contract(all_literal_src), ZeroDivisionException)
        elif fits:
            assert contract.foo(x, y) == expected
            assert get_contract(right_literal_src).foo(x) == expected
            assert get_contract(left_literal_src).foo(y) == expected
            assert get_contract(all_literal_src).foo() == expected
        else:
            assert_tx_failed(lambda: contract.foo(x, y))
            assert_tx_failed(lambda: get_contract(right_literal_src).foo(x))
            assert_tx_failed(lambda: get_contract(left_literal_src).foo(y))
            assert_compile_failed(
                lambda: get_contract(all_literal_src), (InvalidType, OverflowException)
            )
Esempio n. 15
0
def to_decimal(expr, args, kwargs, context):
    """Build the LLL that converts ``args[0]`` to a ``decimal``.

    Decimals are represented as integers scaled by ``DECIMAL_DIVISOR``, so every
    branch produces ``value * DECIMAL_DIVISOR`` after range-checking ``value``.

    Args:
        expr: AST node of the conversion call; used as the error location.
        args: Positional arguments; only ``args[0]`` is converted.
        kwargs: Keyword arguments (unused here).
        context: Compilation context (unused here).

    Returns:
        An ``LLLnode`` of type ``decimal``.

    Raises:
        TypeMismatch: a byte array longer than 32 bytes is given.
        InvalidLiteral: a literal input is out of the decimal range, or the
            input type cannot be converted to decimal.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "Bytes":
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to decimal",
                expr,
            )
        num = byte_array_to_num(in_arg, expr, "int128")
        return LLLnode.from_list(["mul", num, DECIMAL_DIVISOR],
                                 typ=BaseType("decimal"),
                                 pos=getpos(expr))

    else:
        if input_type == "uint256":
            if in_arg.typ.is_literal:
                if not SizeLimits.in_bounds("int128",
                                            (in_arg.value * DECIMAL_DIVISOR)):
                    raise InvalidLiteral(
                        f"Number out of range: {in_arg.value}",
                        expr,
                    )
                else:
                    return LLLnode.from_list(["mul", in_arg, DECIMAL_DIVISOR],
                                             typ=BaseType("decimal"),
                                             pos=getpos(expr))
            else:
                # Range-check BEFORE scaling: clamping the product instead
                # lets a large input wrap around 2**256 in `mul` and pass
                # the check. in_arg <= MAX_INT128 is equivalent to
                # in_arg * DECIMAL_DIVISOR <= MAXDECIMAL without overflow.
                return LLLnode.from_list(
                    [
                        "mul",
                        ["uclample", in_arg, ["mload", MemoryPositions.MAX_INT128]],
                        DECIMAL_DIVISOR,
                    ],
                    typ=BaseType("decimal"),
                    pos=getpos(expr),
                )

        elif input_type == "address":
            return LLLnode.from_list(
                [
                    "mul",
                    [
                        "signextend", 15,
                        ["and", in_arg, (SizeLimits.ADDRSIZE - 1)]
                    ],
                    DECIMAL_DIVISOR,
                ],
                typ=BaseType("decimal"),
                pos=getpos(expr),
            )

        elif input_type == "bytes32":
            if in_arg.typ.is_literal:
                if not SizeLimits.in_bounds("int128",
                                            (in_arg.value * DECIMAL_DIVISOR)):
                    raise InvalidLiteral(
                        f"Number out of range: {in_arg.value}",
                        expr,
                    )
                else:
                    return LLLnode.from_list(["mul", in_arg, DECIMAL_DIVISOR],
                                             typ=BaseType("decimal"),
                                             pos=getpos(expr))
            else:
                # Clamp the (signed) input into int128 range first; clamping
                # the product would miss values whose `mul` overflows.
                return LLLnode.from_list(
                    ["mul", int128_clamp(in_arg), DECIMAL_DIVISOR],
                    typ=BaseType("decimal"),
                    pos=getpos(expr),
                )

        elif input_type == "int256":
            # Use the clamped value directly so `in_arg` is evaluated once.
            return LLLnode.from_list(
                ["mul", int128_clamp(in_arg), DECIMAL_DIVISOR],
                typ=BaseType("decimal"),
                pos=getpos(expr),
            )

        elif input_type in ("int128", "bool"):
            return LLLnode.from_list(["mul", in_arg, DECIMAL_DIVISOR],
                                     typ=BaseType("decimal"),
                                     pos=getpos(expr))

        else:
            raise InvalidLiteral(f"Invalid input for decimal: {in_arg}", expr)
Esempio n. 16
0
File: expr.py Progetto: 6pakla/vyper
    def arithmetic(self):
        """Compile a binary arithmetic expression (``+ - * / % **``) to LLL.

        Integer literal-only expressions are folded at compile time. All other
        expressions bind the operands to ``l`` and ``r`` and apply the overflow
        or range checks required by the result type.

        Returns:
            An ``LLLnode`` whose type is the common operand type.

        Raises:
            TypeMismatch: non-numeric or mismatched operand types, decimal
                exponentiation, or an unsupported exponent.
            ZeroDivisionException: division or modulo by a literal zero.
            StructureException: an unsupported operator.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.right, self.context)

        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            raise TypeMismatch(
                f"Unsupported types for arithmetic op: {left.typ} {right.typ}",
                self.expr,
            )

        arithmetic_pair = {left.typ.typ, right.typ.typ}

        # Special Case: Simplify any literal to literal arithmetic at compile time.
        if left.typ.is_literal and right.typ.is_literal and \
           isinstance(right.value, int) and isinstance(left.value, int) and \
           arithmetic_pair.issubset({'uint256', 'int128'}):

            if isinstance(self.expr.op, vy_ast.Add):
                val = left.value + right.value
            elif isinstance(self.expr.op, vy_ast.Sub):
                val = left.value - right.value
            elif isinstance(self.expr.op, vy_ast.Mult):
                val = left.value * right.value
            elif isinstance(self.expr.op, vy_ast.Pow):
                val = left.value ** right.value
            elif isinstance(self.expr.op, (vy_ast.Div, vy_ast.Mod)):
                if right.value == 0:
                    raise ZeroDivisionException(
                        "integer division or modulo by zero",
                        self.expr,
                    )
                # Fold with EVM semantics so the constant matches what the
                # runtime code computes: `sdiv` truncates toward zero (unlike
                # Python's floor `//`), and `smod` takes the dividend's sign.
                if isinstance(self.expr.op, vy_ast.Div):
                    val = abs(left.value) // abs(right.value)
                    if (left.value < 0) != (right.value < 0):
                        val = -val
                else:
                    val = abs(left.value) % abs(right.value)
                    if left.value < 0:
                        val = -val
            else:
                raise StructureException(
                    f'Unsupported literal operator: {type(self.expr.op)}',
                    self.expr,
                )

            num = vy_ast.Int(value=val)
            num.full_source_code = self.expr.full_source_code
            num.node_source_code = self.expr.node_source_code
            num.lineno = self.expr.lineno
            num.col_offset = self.expr.col_offset
            num.end_lineno = self.expr.end_lineno
            num.end_col_offset = self.expr.end_col_offset

            return Expr.parse_value_expr(num, self.context)

        pos = getpos(self.expr)

        # Special case with uint256 were int literal may be casted.
        if arithmetic_pair == {'uint256', 'int128'}:
            # Check right side literal.
            if right.typ.is_literal and SizeLimits.in_bounds('uint256', right.value):
                right = LLLnode.from_list(
                    right.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=pos,
                )

            # Check left side literal.
            elif left.typ.is_literal and SizeLimits.in_bounds('uint256', left.value):
                left = LLLnode.from_list(
                    left.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=pos,
                )

        if left.typ.typ == "decimal" and isinstance(self.expr.op, vy_ast.Pow):
            raise TypeMismatch(
                "Cannot perform exponentiation on decimal values.",
                self.expr,
            )

        # Only allow explicit conversions to occur.
        if left.typ.typ != right.typ.typ:
            raise TypeMismatch(
                f"Cannot implicitly convert {left.typ.typ} to {right.typ.typ}.",
                self.expr,
            )

        ltyp, rtyp = left.typ.typ, right.typ.typ
        if isinstance(self.expr.op, (vy_ast.Add, vy_ast.Sub)):
            new_typ = BaseType(ltyp)
            op = 'add' if isinstance(self.expr.op, vy_ast.Add) else 'sub'

            if ltyp == 'uint256' and isinstance(self.expr.op, vy_ast.Add):
                # safeadd
                arith = ['seq',
                         ['assert', ['ge', ['add', 'l', 'r'], 'l']],
                         ['add', 'l', 'r']]

            elif ltyp == 'uint256' and isinstance(self.expr.op, vy_ast.Sub):
                # safesub
                arith = ['seq',
                         ['assert', ['ge', 'l', 'r']],
                         ['sub', 'l', 'r']]

            elif ltyp == rtyp:
                arith = [op, 'l', 'r']

            else:
                raise Exception(f"Unsupported Operation '{op}({ltyp}, {rtyp})'")

        elif isinstance(self.expr.op, vy_ast.Mult):
            new_typ = BaseType(ltyp)
            if ltyp == rtyp == 'uint256':
                arith = ['with', 'ans', ['mul', 'l', 'r'],
                         ['seq',
                             ['assert',
                                 ['or',
                                     ['eq', ['div', 'ans', 'l'], 'r'],
                                     ['iszero', 'l']]],
                             'ans']]

            elif ltyp == rtyp == 'int128':
                # TODO should this be 'smul' (note edge cases in YP for smul)
                arith = ['mul', 'l', 'r']

            elif ltyp == rtyp == 'decimal':
                # TODO should this be smul
                arith = ['with', 'ans', ['mul', 'l', 'r'],
                         ['seq',
                             ['assert',
                                 ['or',
                                     ['eq', ['sdiv', 'ans', 'l'], 'r'],
                                     ['iszero', 'l']]],
                             ['sdiv', 'ans', DECIMAL_DIVISOR]]]
            else:
                raise Exception(f"Unsupported Operation 'mul({ltyp}, {rtyp})'")

        elif isinstance(self.expr.op, vy_ast.Div):
            if right.typ.is_literal and right.value == 0:
                raise ZeroDivisionException("Cannot divide by 0.", self.expr)

            new_typ = BaseType(ltyp)
            if ltyp == rtyp == 'uint256':
                arith = ['div', 'l', ['clamp_nonzero', 'r']]

            elif ltyp == rtyp == 'int128':
                arith = ['sdiv', 'l', ['clamp_nonzero', 'r']]

            elif ltyp == rtyp == 'decimal':
                arith = ['sdiv',
                         # TODO check overflow cases, also should it be smul
                         ['mul', 'l', DECIMAL_DIVISOR],
                         ['clamp_nonzero', 'r']]

            else:
                raise Exception(f"Unsupported Operation 'div({ltyp}, {rtyp})'")

        elif isinstance(self.expr.op, vy_ast.Mod):
            if right.typ.is_literal and right.value == 0:
                raise ZeroDivisionException("Cannot calculate modulus of 0.", self.expr)

            new_typ = BaseType(ltyp)

            if ltyp == rtyp == 'uint256':
                arith = ['mod', 'l', ['clamp_nonzero', 'r']]
            elif ltyp == rtyp:
                # TODO should this be regular mod
                arith = ['smod', 'l', ['clamp_nonzero', 'r']]

            else:
                raise Exception(f"Unsupported Operation 'mod({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, vy_ast.Pow):
            if ltyp != 'int128' and ltyp != 'uint256' and isinstance(self.expr.right, vy_ast.Name):
                raise TypeMismatch(
                    "Cannot use dynamic values as exponents, for unit base types",
                    self.expr,
                )
            new_typ = BaseType(ltyp)

            if ltyp == rtyp == 'uint256':
                # NOTE(review): `l < l**r` is only a heuristic overflow check;
                # a wrapped result can still exceed `l`.
                arith = ['seq',
                         ['assert',
                             ['or',
                                 # r == 1 | iszero(r)
                                 # could be simplified to ~(r & 1)
                                 ['or', ['eq', 'r', 1], ['iszero', 'r']],
                                 ['or',
                                     # 0 ** r and 1 ** r never overflow, but
                                     # fail the `l < l**r` test below.
                                     ['le', 'l', 1],
                                     ['lt', 'l', ['exp', 'l', 'r']]]]],
                         ['exp', 'l', 'r']]
            elif ltyp == rtyp == 'int128':
                arith = ['exp', 'l', 'r']

            else:
                raise TypeMismatch('Only whole number exponents are supported', self.expr)
        else:
            raise StructureException(f"Unsupported binary operator: {self.expr.op}", self.expr)

        p = ['seq']

        # Signed results are clamped into their type's range; uint256 results
        # rely on the per-operation checks above.
        if new_typ.typ == 'int128':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINNUM],
                arith,
                ['mload', MemoryPositions.MAXNUM],
            ])
        elif new_typ.typ == 'decimal':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINDECIMAL],
                arith,
                ['mload', MemoryPositions.MAXDECIMAL],
            ])
        elif new_typ.typ == 'uint256':
            p.append(arith)
        else:
            raise Exception(f"{arith} {new_typ}")

        # Evaluate each operand exactly once, binding it to `l` / `r`.
        p = ['with', 'l', left, ['with', 'r', right, p]]
        return LLLnode.from_list(p, typ=new_typ, pos=pos)
Esempio n. 17
0
def to_int128(expr, args, kwargs, context):
    """Build the LLL that converts ``args[0]`` to an ``int128``.

    Args:
        expr: AST node of the conversion call; used as the error location.
        args: Positional arguments; only ``args[0]`` is converted.
        kwargs: Keyword arguments (unused here).
        context: Compilation context (unused here).

    Returns:
        An ``LLLnode`` of type ``int128``. Literal inputs are range-checked at
        compile time; runtime inputs are clamped into the int128 range.

    Raises:
        InvalidLiteral: a literal is out of range or of an unknown numeric type,
            or the input type cannot be converted to int128.
        TypeMismatch: a byte array or string longer than 32 bytes is given.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "num_literal":
        if isinstance(in_arg, int):
            if not SizeLimits.in_bounds("int128", in_arg):
                raise InvalidLiteral(f"Number out of range: {in_arg}", expr)
            return LLLnode.from_list(in_arg,
                                     typ=BaseType("int128"),
                                     pos=getpos(expr))
        elif isinstance(in_arg, Decimal):
            # Decimal literals are truncated toward zero.
            truncated = math.trunc(in_arg)
            if not SizeLimits.in_bounds("int128", truncated):
                raise InvalidLiteral(f"Number out of range: {truncated}", expr)
            return LLLnode.from_list(truncated,
                                     typ=BaseType("int128"),
                                     pos=getpos(expr))
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}", expr)

    elif input_type in ("bytes32", "int256"):
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}",
                                     expr)
            else:
                return LLLnode.from_list(in_arg,
                                         typ=BaseType("int128"),
                                         pos=getpos(expr))
        else:
            return LLLnode.from_list(
                int128_clamp(in_arg),
                typ=BaseType("int128"),
                pos=getpos(expr),
            )

    elif input_type == "address":
        return LLLnode.from_list(
            ["signextend", 15, ["and", in_arg, (SizeLimits.ADDRSIZE - 1)]],
            typ=BaseType("int128"),
            pos=getpos(expr),
        )

    elif input_type in ("String", "Bytes"):
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to int128",
                expr,
            )
        return byte_array_to_num(in_arg, expr, "int128")

    elif input_type == "uint256":
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}",
                                     expr)
            else:
                return LLLnode.from_list(in_arg,
                                         typ=BaseType("int128"),
                                         pos=getpos(expr))

        else:
            # Unsigned upper bound only: a uint256 can never be negative.
            return LLLnode.from_list(
                ["uclample", in_arg, ["mload", MemoryPositions.MAX_INT128]],
                typ=BaseType("int128"),
                pos=getpos(expr),
            )

    elif input_type == "decimal":
        # `sdiv` by the scaling factor drops the fractional part (toward zero).
        return LLLnode.from_list(
            int128_clamp(["sdiv", in_arg, DECIMAL_DIVISOR]),
            typ=BaseType("int128"),
            pos=getpos(expr),
        )

    elif input_type == "bool":
        return LLLnode.from_list(in_arg,
                                 typ=BaseType("int128"),
                                 pos=getpos(expr))

    else:
        raise InvalidLiteral(f"Invalid input for int128: {in_arg}", expr)
Esempio n. 18
0
    def arithmetic(self):
        """Compile a binary arithmetic expression (``+ - * / % **``) to LLL.

        Integer literal-only expressions are folded at compile time. Otherwise
        the operands are fetched (with any pre-allocation code they need), the
        operation is emitted with overflow checks, and unit / positional
        metadata is propagated to the result type.

        Returns:
            An ``LLLnode`` whose type is the common operand type with the
            combined unit.

        Raises:
            TypeMismatchException: non-numeric or mismatched operand types,
                unit or positional misuse, or an unsupported exponent.
            ParserException: an unsupported operator, or a literal division or
                modulo by zero.
        """
        pre_alloc_left, left = self.arithmetic_get_reference(self.expr.left)
        pre_alloc_right, right = self.arithmetic_get_reference(self.expr.right)

        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            raise TypeMismatchException(
                f"Unsupported types for arithmetic op: {left.typ} {right.typ}",
                self.expr,
            )

        arithmetic_pair = {left.typ.typ, right.typ.typ}

        # Special Case: Simplify any literal to literal arithmetic at compile time.
        if left.typ.is_literal and right.typ.is_literal and \
           isinstance(right.value, int) and isinstance(left.value, int) and \
           arithmetic_pair.issubset({'uint256', 'int128'}):

            if isinstance(self.expr.op, ast.Add):
                val = left.value + right.value
            elif isinstance(self.expr.op, ast.Sub):
                val = left.value - right.value
            elif isinstance(self.expr.op, ast.Mult):
                val = left.value * right.value
            elif isinstance(self.expr.op, (ast.Div, ast.Mod)):
                # Report a compile error instead of crashing the compiler with
                # Python's ZeroDivisionError.
                if right.value == 0:
                    raise ParserException(
                        "integer division or modulo by zero",
                        self.expr,
                    )
                # Fold with EVM semantics so the constant matches the runtime
                # code: `sdiv` truncates toward zero (unlike Python's floor
                # `//`), and `smod` takes the dividend's sign.
                if isinstance(self.expr.op, ast.Div):
                    val = abs(left.value) // abs(right.value)
                    if (left.value < 0) != (right.value < 0):
                        val = -val
                else:
                    val = abs(left.value) % abs(right.value)
                    if left.value < 0:
                        val = -val
            elif isinstance(self.expr.op, ast.Pow):
                val = left.value ** right.value
            else:
                raise ParserException(
                    f'Unsupported literal operator: {str(type(self.expr.op))}',
                    self.expr,
                )

            num = ast.Num(n=val)
            num.source_code = self.expr.source_code
            num.lineno = self.expr.lineno
            num.col_offset = self.expr.col_offset
            num.end_lineno = self.expr.end_lineno
            num.end_col_offset = self.expr.end_col_offset

            return Expr.parse_value_expr(num, self.context)

        # Special case with uint256 were int literal may be casted.
        if arithmetic_pair == {'uint256', 'int128'}:
            # Check right side literal.
            if right.typ.is_literal and SizeLimits.in_bounds('uint256', right.value):
                right = LLLnode.from_list(
                    right.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=getpos(self.expr),
                )

            # Check left side literal.
            elif left.typ.is_literal and SizeLimits.in_bounds('uint256', left.value):
                left = LLLnode.from_list(
                    left.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=getpos(self.expr),
                )

        # Only allow explicit conversions to occur.
        if left.typ.typ != right.typ.typ:
            raise TypeMismatchException(
                f"Cannot implicitly convert {left.typ.typ} to {right.typ.typ}.",
                self.expr,
            )

        ltyp, rtyp = left.typ.typ, right.typ.typ
        if isinstance(self.expr.op, (ast.Add, ast.Sub)):
            if left.typ.unit != right.typ.unit and left.typ.unit != {} and right.typ.unit != {}:
                raise TypeMismatchException(
                    f"Unit mismatch: {left.typ.unit} {right.typ.unit}",
                    self.expr,
                )
            if left.typ.positional and right.typ.positional and isinstance(self.expr.op, ast.Add):
                raise TypeMismatchException(
                    "Cannot add two positional units!",
                    self.expr,
                )
            new_unit = left.typ.unit or right.typ.unit

            # xor, as subtracting two positionals gives a delta
            new_positional = left.typ.positional ^ right.typ.positional

            op = 'add' if isinstance(self.expr.op, ast.Add) else 'sub'
            if ltyp == 'uint256' and isinstance(self.expr.op, ast.Add):
                o = LLLnode.from_list([
                    'seq',
                    # Checks that: a + b >= a
                    ['assert', ['ge', ['add', left, right], left]],
                    ['add', left, right],
                ], typ=BaseType('uint256', new_unit, new_positional), pos=getpos(self.expr))
            elif ltyp == 'uint256' and isinstance(self.expr.op, ast.Sub):
                o = LLLnode.from_list([
                    'seq',
                    # Checks that: a >= b
                    ['assert', ['ge', left, right]],
                    ['sub', left, right]
                ], typ=BaseType('uint256', new_unit, new_positional), pos=getpos(self.expr))
            elif ltyp == rtyp:
                o = LLLnode.from_list(
                    [op, left, right],
                    typ=BaseType(ltyp, new_unit, new_positional),
                    pos=getpos(self.expr),
                )
            else:
                raise Exception(f"Unsupported Operation '{op}({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, ast.Mult):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot multiply positional values!", self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit)
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list([
                    'if',
                    ['eq', left, 0],
                    [0],
                    [
                        'seq', ['assert', ['eq', ['div', ['mul', left, right], left], right]],
                        ['mul', left, right]
                    ],
                ], typ=BaseType('uint256', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                o = LLLnode.from_list(
                    ['mul', left, right],
                    typ=BaseType('int128', new_unit),
                    pos=getpos(self.expr),
                )
            elif ltyp == rtyp == 'decimal':
                o = LLLnode.from_list([
                    'with', 'r', right, [
                        'with', 'l', left, [
                            'with', 'ans', ['mul', 'l', 'r'],
                            [
                                'seq',
                                [
                                    'assert',
                                    ['or', ['eq', ['sdiv', 'ans', 'l'], 'r'], ['iszero', 'l']]
                                ],
                                ['sdiv', 'ans', DECIMAL_DIVISOR],
                            ],
                        ],
                    ],
                ], typ=BaseType('decimal', new_unit), pos=getpos(self.expr))
            else:
                raise Exception(f"Unsupported Operation 'mul({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, ast.Div):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot divide positional values!", self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit, div=True)
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list([
                    'seq',
                    # Checks that:  b != 0
                    ['assert', right],
                    ['div', left, right],
                ], typ=BaseType('uint256', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                o = LLLnode.from_list(
                    ['sdiv', left, ['clamp_nonzero', right]],
                    typ=BaseType('int128', new_unit),
                    pos=getpos(self.expr),
                )
            elif ltyp == rtyp == 'decimal':
                o = LLLnode.from_list([
                    'with', 'l', left, [
                        'with', 'r', ['clamp_nonzero', right], [
                            'sdiv',
                            ['mul', 'l', DECIMAL_DIVISOR],
                            'r',
                        ],
                    ]
                ], typ=BaseType('decimal', new_unit), pos=getpos(self.expr))
            else:
                raise Exception(f"Unsupported Operation 'div({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, ast.Mod):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException(
                    "Cannot use positional values as modulus arguments!",
                    self.expr,
                )
            # NOTE(review): this raises only when both operands are unitless,
            # which suggests the second condition may be inverted — confirm
            # against are_units_compatible() before changing.
            if not are_units_compatible(left.typ, right.typ) and not (left.typ.unit or right.typ.unit):  # noqa: E501
                raise TypeMismatchException("Modulus arguments must have same unit", self.expr)
            new_unit = left.typ.unit or right.typ.unit
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list([
                    'seq',
                    ['assert', right],
                    ['mod', left, right]
                ], typ=BaseType('uint256', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp:
                o = LLLnode.from_list(
                    ['smod', left, ['clamp_nonzero', right]],
                    typ=BaseType(ltyp, new_unit),
                    pos=getpos(self.expr),
                )
            else:
                raise Exception(f"Unsupported Operation 'mod({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, ast.Pow):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException(
                    "Cannot use positional values as exponential arguments!",
                    self.expr,
                )
            if right.typ.unit:
                raise TypeMismatchException(
                    "Cannot use unit values as exponents",
                    self.expr,
                )
            if ltyp != 'int128' and ltyp != 'uint256' and isinstance(self.expr.right, ast.Name):
                raise TypeMismatchException(
                    "Cannot use dynamic values as exponents, for unit base types",
                    self.expr,
                )
            if ltyp == rtyp == 'uint256':
                # NOTE(review): `left < left**right` is only a heuristic
                # overflow check; a wrapped result can still exceed `left`.
                o = LLLnode.from_list([
                    'seq',
                    [
                        'assert',
                        [
                            'or',
                            ['or', ['eq', right, 1], ['iszero', right]],
                            [
                                'or',
                                # 0 ** r and 1 ** r never overflow, but fail
                                # the `left < left**right` test below.
                                ['le', left, 1],
                                ['lt', left, ['exp', left, right]],
                            ],
                        ],
                    ],
                    ['exp', left, right],
                ], typ=BaseType('uint256'), pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                new_unit = left.typ.unit
                if left.typ.unit and not isinstance(self.expr.right, ast.Name):
                    new_unit = {left.typ.unit.copy().popitem()[0]: self.expr.right.n}
                o = LLLnode.from_list(
                    ['exp', left, right],
                    typ=BaseType('int128', new_unit),
                    pos=getpos(self.expr),
                )
            else:
                raise TypeMismatchException('Only whole number exponents are supported', self.expr)
        else:
            raise ParserException(f"Unsupported binary operator: {self.expr.op}", self.expr)

        p = ['seq']

        # Any setup code needed to materialize the operand references.
        if pre_alloc_left:
            p.append(pre_alloc_left)
        if pre_alloc_right:
            p.append(pre_alloc_right)

        # Signed results are clamped into their type's range; uint256 results
        # rely on the per-operation checks above.
        if o.typ.typ == 'int128':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINNUM],
                o,
                ['mload', MemoryPositions.MAXNUM],
            ])
            return LLLnode.from_list(p, typ=o.typ, pos=getpos(self.expr))
        elif o.typ.typ == 'decimal':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINDECIMAL],
                o,
                ['mload', MemoryPositions.MAXDECIMAL],
            ])
            return LLLnode.from_list(p, typ=o.typ, pos=getpos(self.expr))
        elif o.typ.typ == 'uint256':
            p.append(o)
            return LLLnode.from_list(p, typ=o.typ, pos=getpos(self.expr))
        else:
            raise Exception(f"{o} {o.typ}")
Esempio n. 19
0
File: expr.py Progetto: zutobg/vyper
    def compare(self):
        """Compile a single binary comparison (``< <= > >= == != in``) to LLL.

        Byte arrays of equal ``maxlen`` (at most 32) are compared by loading
        their data word. ``in`` against a list is delegated to
        ``build_in_comparator``. Numeric comparisons use signed opcodes, except
        for uint256 operands (or a mixed int128/uint256 pair with an in-range
        literal), which use unsigned opcodes.

        Returns:
            An ``LLLnode`` of type ``bool``.

        Raises:
            TypeMismatchException: incompatible operand types or units.
            ParserException: byte arrays longer than 32 bytes, or byte arrays
                in an unsupported location.
            StructureException: a chained comparison.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.comparators[0], self.context)

        if isinstance(left.typ, ByteArrayType) and isinstance(right.typ, ByteArrayType):
            if left.typ.maxlen != right.typ.maxlen:
                raise TypeMismatchException('Can only compare bytes of the same length', self.expr)
            if left.typ.maxlen > 32 or right.typ.maxlen > 32:
                raise ParserException('Can only compare bytes of length shorter than 32 bytes', self.expr)
        elif isinstance(self.expr.ops[0], ast.In) and \
           isinstance(right.typ, ListType):
            if not are_units_compatible(left.typ, right.typ.subtype) and \
               not are_units_compatible(right.typ.subtype, left.typ):
                raise TypeMismatchException("Can't use IN comparison with different types!", self.expr)
            return self.build_in_comparator()
        else:
            if not are_units_compatible(left.typ, right.typ) and \
               not are_units_compatible(right.typ, left.typ):
                raise TypeMismatchException("Can't compare values with different units!", self.expr)

        if len(self.expr.ops) != 1:
            raise StructureException("Cannot have a comparison with more than two elements", self.expr)
        if isinstance(self.expr.ops[0], ast.Gt):
            op = 'sgt'
        elif isinstance(self.expr.ops[0], ast.GtE):
            op = 'sge'
        elif isinstance(self.expr.ops[0], ast.LtE):
            op = 'sle'
        elif isinstance(self.expr.ops[0], ast.Lt):
            op = 'slt'
        elif isinstance(self.expr.ops[0], ast.Eq):
            op = 'eq'
        elif isinstance(self.expr.ops[0], ast.NotEq):
            op = 'ne'
        else:
            raise Exception("Unsupported comparison operator")

        # Compare (limited to 32) byte arrays. Both sides must be byte arrays;
        # the original check tested `left` twice and let a non-byte-array
        # right operand reach load_bytearray().
        if isinstance(left.typ, ByteArrayType) and isinstance(right.typ, ByteArrayType):
            left = Expr(self.expr.left, self.context).lll_node
            right = Expr(self.expr.comparators[0], self.context).lll_node

            def load_bytearray(side):
                # The data word follows the 32-byte length slot.
                if side.location == 'memory':
                    return ['mload', ['add', 32, side]]
                elif side.location == 'storage':
                    return ['sload', ['add', 1, ['sha3_32', side]]]
                # Previously fell through and returned None into the LLL.
                raise ParserException(
                    f"Unsupported location for bytes comparison: {side.location}",
                    self.expr,
                )

            return LLLnode.from_list(
                [op, load_bytearray(left), load_bytearray(right)], typ='bool', pos=getpos(self.expr))

        # Compare other types.
        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            if op not in ('eq', 'ne'):
                raise TypeMismatchException("Invalid type for comparison op", self.expr)
        left_type, right_type = left.typ.typ, right.typ.typ

        # Special Case: comparison of a literal integer. If in valid range allow it to be compared.
        if {left_type, right_type} == {'int128', 'uint256'} and \
           {left.typ.is_literal, right.typ.is_literal} == {True, False}:

            comparison_allowed = False
            if left.typ.is_literal and SizeLimits.in_bounds(right_type, left.value):
                comparison_allowed = True
            elif right.typ.is_literal and SizeLimits.in_bounds(left_type, right.value):
                comparison_allowed = True
            op = self._signed_to_unsigned_comparision_op(op)

            if comparison_allowed:
                return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))

        elif left_type == right_type == 'uint256':
            op = self._signed_to_unsigned_comparision_op(op)
        elif (left_type in ('decimal', 'int128') or right_type in ('decimal', 'int128')) and left_type != right_type:
            raise TypeMismatchException(
                'Implicit conversion from {} to {} disallowed, please convert.'.format(left_type, right_type),
                self.expr
            )

        if left_type == right_type:
            return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))
        else:
            raise TypeMismatchException("Unsupported types for comparison: %r %r" % (left_type, right_type), self.expr)
Esempio n. 20
0
    def compare(self):
        """Compile a single binary comparison (``<``, ``<=``, ``>``, ``>=``, ``==``, ``!=``, ``in``).

        Returns an LLL node of type ``bool``.

        Raises:
            InvalidLiteralException: the right-hand side is ``None``.
            TypeMismatchException: operand types or units are incompatible, an
                ordering operator is applied to non-numeric values, or an ``in``
                test mixes element types.
            StructureException: more than one comparison operator (chained
                comparisons).
            ParserException: an ordering comparison between byte arrays whose
                max lengths differ or exceed 32 bytes.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.comparators[0], self.context)

        if isinstance(right.typ, NullType):
            raise InvalidLiteralException(
                'Comparison to None is not allowed, compare against a default value.',
                self.expr,
            )

        if isinstance(left.typ, ByteArrayLike) and isinstance(right.typ, ByteArrayLike):
            # TODO: Can this if branch be removed ^
            pass

        elif isinstance(self.expr.ops[0], ast.In) and isinstance(right.typ, ListType):
            if left.typ != right.typ.subtype:
                raise TypeMismatchException(
                    "Can't use IN comparison with different types!",
                    self.expr,
                )
            return self.build_in_comparator()
        else:
            if not are_units_compatible(left.typ, right.typ) and not are_units_compatible(right.typ, left.typ):  # noqa: E501
                raise TypeMismatchException("Can't compare values with different units!", self.expr)

        if len(self.expr.ops) != 1:
            raise StructureException(
                "Cannot have a comparison with more than two elements",
                self.expr,
            )
        # Ordering ops start out as their signed EVM forms; they are swapped for
        # the unsigned forms further down when the operands are uint256.
        if isinstance(self.expr.ops[0], ast.Gt):
            op = 'sgt'
        elif isinstance(self.expr.ops[0], ast.GtE):
            op = 'sge'
        elif isinstance(self.expr.ops[0], ast.LtE):
            op = 'sle'
        elif isinstance(self.expr.ops[0], ast.Lt):
            op = 'slt'
        elif isinstance(self.expr.ops[0], ast.Eq):
            op = 'eq'
        elif isinstance(self.expr.ops[0], ast.NotEq):
            op = 'ne'
        else:
            raise Exception("Unsupported comparison operator")

        # Compare (limited to 32) byte arrays.
        if isinstance(left.typ, ByteArrayLike) and isinstance(right.typ, ByteArrayLike):
            left = Expr(self.expr.left, self.context).lll_node
            right = Expr(self.expr.comparators[0], self.context).lll_node

            length_mismatch = (left.typ.maxlen != right.typ.maxlen)
            left_over_32 = left.typ.maxlen > 32
            right_over_32 = right.typ.maxlen > 32
            if length_mismatch or left_over_32 or right_over_32:
                # The values cannot be compared as single words, so compare
                # their keccak256 hashes instead; that only supports (in)equality.
                left_keccak = keccak256_helper(self.expr, [left], None, self.context)
                right_keccak = keccak256_helper(self.expr, [right], None, self.context)

                if op == 'eq' or op == 'ne':
                    return LLLnode.from_list(
                        [op, left_keccak, right_keccak],
                        typ='bool',
                        pos=getpos(self.expr),
                    )

                else:
                    # The message must be ONE argument (implicit string
                    # concatenation, no comma) so that `self.expr` is passed as
                    # the node the error refers to.
                    raise ParserException(
                        "Can only compare strings/bytes of length shorter"
                        " than 32 bytes other than equality comparisons",
                        self.expr,
                    )

            else:
                def load_bytearray(side):
                    # Load the first data word of the byte array, i.e. the word
                    # after its length word (assumes the standard byte-array
                    # layout in memory / storage).
                    if side.location == 'memory':
                        return ['mload', ['add', 32, side]]
                    elif side.location == 'storage':
                        return ['sload', ['add', 1, ['sha3_32', side]]]
                    # Previously fell through and returned None, producing a
                    # malformed LLL expression.
                    raise Exception(f"Invalid location: {side.location}")

                return LLLnode.from_list(
                    [op, load_bytearray(left), load_bytearray(right)],
                    typ='bool',
                    pos=getpos(self.expr),
                )

        # Compare other types.
        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            if op not in ('eq', 'ne'):
                raise TypeMismatchException("Invalid type for comparison op", self.expr)
        left_type, right_type = left.typ.typ, right.typ.typ

        # Special Case: comparison of a literal integer. If in valid range allow it to be compared.
        if {left_type, right_type} == {'int128', 'uint256'} and {left.typ.is_literal, right.typ.is_literal} == {True, False}:  # noqa: E501

            comparison_allowed = False
            if left.typ.is_literal and SizeLimits.in_bounds(right_type, left.value):
                comparison_allowed = True
            elif right.typ.is_literal and SizeLimits.in_bounds(left_type, right.value):
                comparison_allowed = True
            op = self._signed_to_unsigned_comparision_op(op)

            if comparison_allowed:
                return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))

        elif left_type == right_type == 'uint256':
            op = self._signed_to_unsigned_comparision_op(op)
        elif (left_type in ('decimal', 'int128') or right_type in ('decimal', 'int128')) and left_type != right_type:  # noqa: E501
            raise TypeMismatchException(
                f'Implicit conversion from {left_type} to {right_type} disallowed, please convert.',
                self.expr,
            )

        if left_type == right_type:
            return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))
        else:
            raise TypeMismatchException(
                f"Unsupported types for comparison: {left_type} {right_type}",
                self.expr,
            )
Esempio n. 21
0
    def parse_return(self):
        """Compile a ``return`` statement into LLL.

        Handles:
        * returning nothing from a function without a declared return type;
        * a base-type value (written to memory slot 0 and returned as 32 bytes);
        * a byte array from memory or storage, with a leading ``32`` word and
          zero-padding of the unused tail;
        * a list;
        * a tuple of base types and/or byte arrays, with byte arrays stored
          in the dynamic section after the static head.

        Raises:
            TypeMismatchException: the returned value does not match the
                function's declared return type.
            StructureException: tuple lengths differ.
        """
        from .parser import (make_setter)
        if self.context.return_type is None:
            if self.stmt.value:
                raise TypeMismatchException("Not expecting to return a value",
                                            self.stmt)
            return LLLnode.from_list(['return', 0, 0],
                                     typ=None,
                                     pos=getpos(self.stmt))
        if not self.stmt.value:
            raise TypeMismatchException("Expecting to return a value",
                                        self.stmt)

        sub = Expr(self.stmt.value, self.context).lll_node
        self.context.increment_return_counter()
        # Returning a value (most common case)
        if isinstance(sub.typ, BaseType):
            if not isinstance(self.context.return_type, BaseType):
                raise TypeMismatchException(
                    "Trying to return base type %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            sub = unwrap_location(sub)
            if not are_units_compatible(sub.typ, self.context.return_type):
                raise TypeMismatchException(
                    "Return type units mismatch %r %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            elif is_base_type(sub.typ, self.context.return_type.typ) or \
                    (is_base_type(sub.typ, 'int128') and is_base_type(self.context.return_type, 'int256')):
                return LLLnode.from_list(
                    ['seq', ['mstore', 0, sub], ['return', 0, 32]],
                    typ=None,
                    pos=getpos(self.stmt))
            # A literal of another type is accepted when its value fits the
            # declared return type.
            if sub.typ.is_literal and SizeLimits.in_bounds(
                    self.context.return_type.typ, sub.value):
                return LLLnode.from_list(
                    ['seq', ['mstore', 0, sub], ['return', 0, 32]],
                    typ=None,
                    pos=getpos(self.stmt))
            else:
                raise TypeMismatchException(
                    "Unsupported type conversion: %r to %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
        # Returning a byte array
        elif isinstance(sub.typ, ByteArrayType):
            if not isinstance(self.context.return_type, ByteArrayType):
                raise TypeMismatchException(
                    "Trying to return base type %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            if sub.typ.maxlen > self.context.return_type.maxlen:
                raise TypeMismatchException(
                    "Cannot cast from greater max-length %d to shorter max-length %d"
                    % (sub.typ.maxlen, self.context.return_type.maxlen),
                    self.stmt.value)

            # Zero the bytes between the array's current length (mload _loc)
            # and its max length so no stale memory leaks into the output.
            zero_padder = LLLnode.from_list(['pass'])
            if sub.typ.maxlen > 0:
                zero_pad_i = self.context.new_placeholder(
                    BaseType('uint256'))  # Iterator used to zero pad memory.
                zero_padder = LLLnode.from_list(
                    [
                        'repeat',
                        zero_pad_i,
                        ['mload', '_loc'],
                        sub.typ.maxlen,
                        [
                            'seq',
                            [
                                'if',
                                ['gt', ['mload', zero_pad_i], sub.typ.maxlen],
                                'break'
                            ],  # stay within allocated bounds
                            [
                                'mstore8',
                                [
                                    'add', ['add', 32, '_loc'],
                                    ['mload', zero_pad_i]
                                ], 0
                            ]
                        ]
                    ],
                    annotation="Zero pad")

            # Returning something already in memory
            if sub.location == 'memory':
                return LLLnode.from_list([
                    'with', '_loc', sub,
                    [
                        'seq', ['mstore', ['sub', '_loc', 32], 32],
                        zero_padder,
                        [
                            'return', ['sub', '_loc', 32],
                            ['ceil32', ['add', ['mload', '_loc'], 64]]
                        ]
                    ]
                ],
                                         typ=None,
                                         pos=getpos(self.stmt))

            # Copying from storage
            elif sub.location == 'storage':
                # Instantiate a byte array at some index
                fake_byte_array = LLLnode(self.context.get_next_mem() + 32,
                                          typ=sub.typ,
                                          location='memory',
                                          pos=getpos(self.stmt))
                o = [
                    'with',
                    '_loc',
                    self.context.get_next_mem() + 32,
                    [
                        'seq',
                        # Copy the data to this byte array
                        make_byte_array_copier(fake_byte_array, sub),
                        # Store the number 32 before it for ABI formatting purposes
                        ['mstore', self.context.get_next_mem(), 32],
                        zero_padder,
                        # Return it
                        [
                            'return',
                            self.context.get_next_mem(),
                            [
                                'add',
                                [
                                    'ceil32',
                                    [
                                        'mload',
                                        self.context.get_next_mem() + 32
                                    ]
                                ], 64
                            ]
                        ]
                    ]
                ]
                return LLLnode.from_list(o, typ=None, pos=getpos(self.stmt))
            else:
                raise Exception("Invalid location: %s" % sub.location)

        elif isinstance(sub.typ, ListType):
            # Guard before touching `.subtype`: a non-list return type used to
            # surface as an AttributeError instead of a type error.
            if not isinstance(self.context.return_type, ListType):
                raise TypeMismatchException(
                    "Trying to return list type %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            # Compare only the base element type names (text before any
            # '(' or '[' in the type's string form).
            sub_base_type = re.split(r'\(|\[', str(sub.typ.subtype))[0]
            ret_base_type = re.split(r'\(|\[',
                                     str(self.context.return_type.subtype))[0]
            if sub_base_type != ret_base_type:
                raise TypeMismatchException(
                    "List return type %r does not match specified return type, expecting %r"
                    % (sub_base_type, ret_base_type), self.stmt)
            elif sub.location == "memory" and sub.value != "multi":
                return LLLnode.from_list([
                    'return', sub,
                    get_size_of_type(self.context.return_type) * 32
                ],
                                         typ=None,
                                         pos=getpos(self.stmt))
            else:
                new_sub = LLLnode.from_list(self.context.new_placeholder(
                    self.context.return_type),
                                            typ=self.context.return_type,
                                            location='memory')
                setter = make_setter(new_sub,
                                     sub,
                                     'memory',
                                     pos=getpos(self.stmt))
                return LLLnode.from_list([
                    'seq', setter,
                    [
                        'return', new_sub,
                        get_size_of_type(self.context.return_type) * 32
                    ]
                ],
                                         typ=None,
                                         pos=getpos(self.stmt))

        # Returning a tuple.
        elif isinstance(sub.typ, TupleType):
            # Guard before touching `.members` (same reason as the list case).
            if not isinstance(self.context.return_type, TupleType):
                raise TypeMismatchException(
                    "Trying to return tuple type %r, output expecting %r" %
                    (sub.typ, self.context.return_type), self.stmt.value)
            if len(self.context.return_type.members) != len(sub.typ.members):
                raise StructureException("Tuple lengths don't match!",
                                         self.stmt)

            subs = []
            dynamic_offset_counter = LLLnode(
                self.context.get_next_mem(),
                typ=None,
                annotation="dynamic_offset_counter"
            )  # dynamic offset position counter.
            new_sub = LLLnode.from_list(self.context.get_next_mem() + 32,
                                        typ=self.context.return_type,
                                        location='memory',
                                        annotation='new_sub')
            dynamic_offset_start = 32 * len(
                sub.args)  # The static list of args end.
            left_token = LLLnode.from_list('_loc',
                                           typ=new_sub.typ,
                                           location="memory")

            def get_dynamic_offset_value():
                # Get value of dynamic offset counter.
                return ['mload', dynamic_offset_counter]

            def increment_dynamic_offset(dynamic_spot):
                # Advance the counter past the byte array just written:
                # its 32-byte length word plus its data rounded up to 32 bytes.
                return [
                    'mstore', dynamic_offset_counter,
                    [
                        'add',
                        ['add', ['ceil32', ['mload', dynamic_spot]], 32],
                        ['mload', dynamic_offset_counter]
                    ]
                ]

            for i in range(len(sub.typ.members)):
                arg = sub.args[i]
                variable_offset = LLLnode.from_list(
                    ['add', 32 * i, left_token],
                    typ=arg.typ,
                    annotation='variable_offset')
                if isinstance(arg.typ, ByteArrayType):
                    # Store offset pointer value.
                    subs.append([
                        'mstore', variable_offset,
                        get_dynamic_offset_value()
                    ])

                    # Store dynamic data, from offset pointer onwards.
                    dynamic_spot = LLLnode.from_list(
                        ['add', left_token,
                         get_dynamic_offset_value()],
                        location="memory",
                        typ=arg.typ,
                        annotation='dynamic_spot')
                    subs.append(
                        make_setter(dynamic_spot,
                                    arg,
                                    location="memory",
                                    pos=getpos(self.stmt)))
                    subs.append(increment_dynamic_offset(dynamic_spot))

                elif isinstance(arg.typ, BaseType):
                    subs.append(
                        make_setter(variable_offset,
                                    arg,
                                    "memory",
                                    pos=getpos(self.stmt)))
                else:
                    # Format the message; previously the type was passed as a
                    # second Exception argument and never interpolated.
                    raise Exception("Can't return type %s as part of tuple" %
                                    type(arg.typ))

            setter = LLLnode.from_list([
                'seq', [
                    'mstore', dynamic_offset_counter, dynamic_offset_start
                ], ['with', '_loc', new_sub, ['seq'] + subs]
            ],
                                       typ=None)

            return LLLnode.from_list([
                'seq', setter, ['return', new_sub,
                                get_dynamic_offset_value()]
            ],
                                     typ=None,
                                     pos=getpos(self.stmt))
        else:
            raise TypeMismatchException("Can only return base type!",
                                        self.stmt)
Esempio n. 22
0
def process_arg(index, arg, expected_arg_typelist, function_name, context):
    """Match one builtin-call argument against its accepted type spec(s).

    ``expected_arg_typelist`` is a single spec or a tuple of specs, optionally
    wrapped in ``Optional``. Specs are tried in order and the first match wins:

    * ``'num_literal'``: the numeric value (an integer/uint constant or an
      ``ast.Num`` that is not 0x/0b-prefixed);
    * ``'str_literal'``: the string literal encoded as ``bytes``;
    * ``'name_literal'``: the bare name, or ``'bytes[N]'`` for ``bytes[N]``;
    * ``'*'``: the AST node unchanged;
    * ``'bytes'`` / ``'string'``: the compiled LLL node when its type matches;
    * anything else: parsed as a type, and the compiled node is returned when
      its type matches. ``int128``/``uint256`` literals that fit the expected
      integer type are also accepted.

    Raises:
        InvalidLiteralException: a string literal contains a character >= 256.
        TypeMismatchException: no spec matched.
    """
    if isinstance(expected_arg_typelist, Optional):
        expected_arg_typelist = expected_arg_typelist.typ
    if not isinstance(expected_arg_typelist, tuple):
        expected_arg_typelist = (expected_arg_typelist, )

    # Cached compilation of `arg`, reused across candidate specs.
    vsub = None
    for expected_arg in expected_arg_typelist:
        if expected_arg == 'num_literal':
            if context.constants.is_constant_of_base_type(
                    arg, ('uint256', 'int128')):
                return context.constants.get_constant(arg.id, None).value
            if isinstance(arg, ast.Num) and get_original_if_0_prefixed(
                    arg, context) is None:
                return arg.n
        elif expected_arg == 'str_literal':
            if isinstance(arg, ast.Str) and get_original_if_0_prefixed(
                    arg, context) is None:
                for c in arg.s:
                    if ord(c) >= 256:
                        raise InvalidLiteralException(
                            "Cannot insert special character %r into byte array"
                            % c, arg)
                # Every character is < 256 here, so each maps to one byte.
                return bytes(ord(c) for c in arg.s)
        elif expected_arg == 'name_literal':
            if isinstance(arg, ast.Name):
                return arg.id
            # Guard `arg.value` so e.g. `foo.bar[3]` falls through to a
            # TypeMismatchException instead of raising AttributeError.
            elif isinstance(arg, ast.Subscript) and \
                    isinstance(arg.value, ast.Name) and arg.value.id == 'bytes':
                return 'bytes[%s]' % arg.slice.value.n
        elif expected_arg == '*':
            return arg
        elif expected_arg == 'bytes':
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, ByteArrayType):
                return sub
        elif expected_arg == 'string':
            sub = Expr(arg, context).lll_node
            if isinstance(sub.typ, StringType):
                return sub
        else:
            # Does not work for unit-endowed types inside compound types, e.g. timestamp[2]
            parsed_expected_type = context.parse_type(
                ast.parse(expected_arg).body[0].value, 'memory')
            if isinstance(parsed_expected_type, BaseType):
                vsub = vsub or Expr.parse_value_expr(arg, context)
                if is_base_type(vsub.typ, expected_arg):
                    return vsub
                elif expected_arg in ('int128', 'uint256') and \
                     isinstance(vsub.typ, BaseType) and \
                     vsub.typ.typ in ('int128', 'uint256') and \
                     vsub.typ.is_literal and \
                     SizeLimits.in_bounds(expected_arg, vsub.value):
                    return vsub
            else:
                vsub = vsub or Expr(arg, context).lll_node
                if vsub.typ == parsed_expected_type:
                    return Expr(arg, context).lll_node
    if len(expected_arg_typelist) == 1:
        raise TypeMismatchException(
            "Expecting %s for argument %r of %s" %
            (expected_arg, index, function_name), arg)
    else:
        raise TypeMismatchException(
            "Expecting one of %r for argument %r of %s" %
            (expected_arg_typelist, index, function_name), arg)
Esempio n. 23
0
def to_decimal(expr, args, kwargs, context):
    """Convert ``args[0]`` into a ``decimal`` LLL node.

    Decimals are represented as the input multiplied by ``DECIMAL_DIVISOR``.
    Accepted inputs are byte arrays of at most 32 bytes, ``uint256``,
    ``address``, ``bytes32``, ``int128`` and ``bool``. For non-bytes inputs
    the unit and positional annotations of the input are kept on the result.
    ``kwargs`` and ``context`` are unused here.

    Raises:
        TypeMismatchException: a byte array whose max length exceeds 32.
        InvalidLiteralException: a literal whose scaled value is out of range,
            or an unsupported input type.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == 'bytes':
        if in_arg.typ.maxlen > 32:
            raise TypeMismatchException(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to decimal",
                expr,
            )
        num = byte_array_to_num(in_arg, expr, 'int128')
        return LLLnode.from_list(
            ['mul', num, DECIMAL_DIVISOR],
            typ=BaseType('decimal'),
            pos=getpos(expr)
        )

    else:
        _unit = in_arg.typ.unit
        _positional = in_arg.typ.positional

        if input_type == 'uint256':
            if in_arg.typ.is_literal:
                if not SizeLimits.in_bounds('int128', (in_arg.value * DECIMAL_DIVISOR)):
                    raise InvalidLiteralException(
                        f"Number out of range: {in_arg.value}",
                        expr,
                    )
                else:
                    return LLLnode.from_list(
                        ['mul', in_arg, DECIMAL_DIVISOR],
                        typ=BaseType('decimal', _unit, _positional),
                        pos=getpos(expr)
                    )
            else:
                # Bound the input *before* scaling. EVM `mul` wraps modulo
                # 2**256, so clamping only the product lets huge inputs wrap
                # back into range. For unsigned x, x * D <= MAX  <=>  x <= MAX // D.
                return LLLnode.from_list(
                    [
                        'mul',
                        [
                            'uclample',
                            in_arg,
                            ['div', ['mload', MemoryPositions.MAXDECIMAL], DECIMAL_DIVISOR],
                        ],
                        DECIMAL_DIVISOR,
                    ],
                    typ=BaseType('decimal', _unit, _positional),
                    pos=getpos(expr)
                )

        elif input_type == 'address':
            # Masked and sign-extended from byte 15 before scaling, so the
            # multiplication below cannot overflow.
            return LLLnode.from_list(
                [
                    'mul',
                    [
                        'signextend',
                        15,
                        [
                            'and',
                            in_arg,
                            (SizeLimits.ADDRSIZE - 1)
                        ],
                    ],
                    DECIMAL_DIVISOR
                ],
                typ=BaseType('decimal', _unit, _positional),
                pos=getpos(expr)
            )

        elif input_type == 'bytes32':
            if in_arg.typ.is_literal:
                if not SizeLimits.in_bounds('int128', (in_arg.value * DECIMAL_DIVISOR)):
                    raise InvalidLiteralException(
                        f"Number out of range: {in_arg.value}",
                        expr,
                    )
                else:
                    return LLLnode.from_list(
                        ['mul', in_arg, DECIMAL_DIVISOR],
                        typ=BaseType('decimal', _unit, _positional),
                        pos=getpos(expr)
                    )
            else:
                # Same wraparound concern as uint256, using signed bounds:
                # MIN <= x * D <= MAX  <=>  sdiv(MIN, D) <= x <= sdiv(MAX, D)
                # (sdiv truncates toward zero, i.e. ceil for negative MIN).
                return LLLnode.from_list(
                    [
                        'mul',
                        [
                            'clamp',
                            ['sdiv', ['mload', MemoryPositions.MINDECIMAL], DECIMAL_DIVISOR],
                            in_arg,
                            ['sdiv', ['mload', MemoryPositions.MAXDECIMAL], DECIMAL_DIVISOR],
                        ],
                        DECIMAL_DIVISOR,
                    ],
                    typ=BaseType('decimal', _unit, _positional),
                    pos=getpos(expr)
                )

        elif input_type in ('int128', 'bool'):
            return LLLnode.from_list(
                ['mul', in_arg, DECIMAL_DIVISOR],
                typ=BaseType('decimal', _unit, _positional),
                pos=getpos(expr)
            )

        else:
            raise InvalidLiteralException(f"Invalid input for decimal: {in_arg}", expr)
Esempio n. 24
0
def process_arg(index, arg, expected_arg_typelist, function_name, context):
    """Resolve one builtin-call argument against its accepted type spec(s).

    ``expected_arg_typelist`` may be an abstract type (exposing ``_id_list``),
    an ``Optional`` wrapper, a single spec, or a tuple of specs. Specs are
    tried in order and the first match decides the result:

    * ``"num_literal"``: ``arg.n`` for ``Int``/``Decimal`` nodes;
    * ``"str_literal"``: the string literal encoded as ``bytes``;
    * ``"bytes_literal"``: ``arg.s`` for ``Bytes`` nodes;
    * ``"name_literal"``: the bare name, or ``"Bytes[N]"`` for ``Bytes[N]``;
    * ``"*"``: the AST node unchanged;
    * ``"Bytes"`` / ``"String"``: the compiled IR node when its type matches;
    * anything else: parsed as a type. The compiled node is returned when its
      type matches, or when it is an integer literal that fits the expected
      integer type.

    Raises:
        InvalidLiteral: a string literal contains a character >= 256.
        TypeMismatch: no spec matched.
    """
    # temporary hack to support abstract types
    if hasattr(expected_arg_typelist, "_id_list"):
        expected_arg_typelist = expected_arg_typelist._id_list

    if isinstance(expected_arg_typelist, Optional):
        expected_arg_typelist = expected_arg_typelist.typ
    if not isinstance(expected_arg_typelist, tuple):
        expected_arg_typelist = (expected_arg_typelist,)

    # Compiled form of `arg`, computed lazily and shared between specs.
    compiled = None
    for expected_arg in expected_arg_typelist:
        # temporary hack, once we refactor this package none of this will exist
        expected_arg = getattr(expected_arg, "_id", expected_arg)

        if expected_arg == "num_literal":
            if isinstance(arg, (vy_ast.Int, vy_ast.Decimal)):
                return arg.n
            continue

        if expected_arg == "str_literal":
            if isinstance(arg, vy_ast.Str):
                for char in arg.s:
                    if ord(char) >= 256:
                        raise InvalidLiteral(
                            f"Cannot insert special character {char} into byte array",
                            arg,
                        )
                return bytes(ord(char) for char in arg.s)
            continue

        if expected_arg == "bytes_literal":
            if isinstance(arg, vy_ast.Bytes):
                return arg.s
            continue

        if expected_arg == "name_literal":
            if isinstance(arg, vy_ast.Name):
                return arg.id
            if isinstance(arg, vy_ast.Subscript) and arg.value.id == "Bytes":
                return f"Bytes[{arg.slice.value.n}]"
            continue

        if expected_arg == "*":
            return arg

        if expected_arg in ("Bytes", "String"):
            node = Expr(arg, context).ir_node
            wanted_type = ByteArrayType if expected_arg == "Bytes" else StringType
            if isinstance(node.typ, wanted_type):
                return node
            continue

        parsed_expected_type = context.parse_type(
            vy_ast.parse_to_ast(expected_arg)[0].value)
        if not isinstance(parsed_expected_type, BaseType):
            compiled = compiled or Expr(arg, context).ir_node
            if compiled.typ == parsed_expected_type:
                return Expr(arg, context).ir_node
            continue

        compiled = compiled or Expr.parse_value_expr(arg, context)
        fits_integer = (
            expected_arg in INTEGER_TYPES
            and isinstance(compiled.typ, BaseType)
            and compiled.typ.typ in INTEGER_TYPES
            and compiled.typ.is_literal
            and SizeLimits.in_bounds(expected_arg, compiled.value)
        )
        if is_base_type(compiled.typ, expected_arg) or fits_integer:
            return compiled

    if len(expected_arg_typelist) != 1:
        raise TypeMismatch(
            f"Expecting one of {expected_arg_typelist} for argument {index} of {function_name}",
            arg)
    raise TypeMismatch(
        f"Expecting {expected_arg} for argument {index} of {function_name}",
        arg)
Esempio n. 25
0
 def _can_compare_with_uint256(operand):
     """Return True when ``operand`` may be compared against a ``uint256``.

     That is the case for ``uint256`` operands and for ``int128`` literals
     whose value lies within the ``uint256`` range. Anything else is False.
     """
     operand_type = operand.typ
     if operand_type.typ == 'uint256':
         return True
     fits_as_literal = (
         operand_type.typ == 'int128'
         and operand_type.is_literal
         and SizeLimits.in_bounds('uint256', operand.value)
     )
     return bool(fits_as_literal)
Esempio n. 26
0
    def number(self):
        orignum = get_original_if_0_prefixed(self.expr, self.context)

        if orignum is None and isinstance(self.expr.n, int):
            # Literal (mostly likely) becomes int128
            if SizeLimits.in_bounds('int128', self.expr.n) or self.expr.n < 0:
                return LLLnode.from_list(self.expr.n,
                                         typ=BaseType('int128',
                                                      unit=None,
                                                      is_literal=True),
                                         pos=getpos(self.expr))
            # Literal is large enough (mostly likely) becomes uint256.
            else:
                return LLLnode.from_list(self.expr.n,
                                         typ=BaseType('uint256',
                                                      unit=None,
                                                      is_literal=True),
                                         pos=getpos(self.expr))

        elif isinstance(self.expr.n, float):
            numstring, num, den = get_number_as_fraction(
                self.expr, self.context)
            # if not SizeLimits.in_bounds('decimal', num // den):
            # if not SizeLimits.MINDECIMAL * den <= num <= SizeLimits.MAXDECIMAL * den:
            if not (SizeLimits.MINNUM * den < num < SizeLimits.MAXNUM * den):
                raise InvalidLiteralException(
                    "Number out of range: " + numstring, self.expr)
            if DECIMAL_DIVISOR % den:
                raise InvalidLiteralException(
                    "Too many decimal places: " + numstring, self.expr)
            return LLLnode.from_list(num * DECIMAL_DIVISOR // den,
                                     typ=BaseType('decimal', unit=None),
                                     pos=getpos(self.expr))
        # Binary literal.
        elif orignum[:2] == '0b':
            str_val = orignum[2:]
            total_bits = len(orignum[2:])
            total_bits = total_bits if total_bits % 8 == 0 else total_bits + 8 - (
                total_bits % 8)  # ceil8 to get byte length.
            if len(
                    orignum[2:]
            ) != total_bits:  # Support only full formed bit definitions.
                raise InvalidLiteralException(
                    "Bit notation requires a multiple of 8 bits / 1 byte. {} bit(s) are missing."
                    .format(total_bits - len(orignum[2:])), self.expr)
            byte_len = int(total_bits / 8)
            placeholder = self.context.new_placeholder(ByteArrayType(byte_len))
            seq = []
            seq.append(['mstore', placeholder, byte_len])
            for i in range(0, total_bits, 256):
                section = str_val[i:i + 256]
                int_val = int(section, 2) << (256 - len(section)
                                              )  # bytes are right padded.
                seq.append(['mstore', ['add', placeholder, i + 32], int_val])
            return LLLnode.from_list(
                ['seq'] + seq + [placeholder],
                typ=ByteArrayType(byte_len),
                location='memory',
                pos=getpos(self.expr),
                annotation='Create ByteArray (Binary literal): %s' % str_val)
        elif len(orignum) == 42:
            if checksum_encode(orignum) != orignum:
                raise InvalidLiteralException(
                    """Address checksum mismatch. If you are sure this is the
right address, the correct checksummed form is: %s""" %
                    checksum_encode(orignum), self.expr)
            return LLLnode.from_list(self.expr.n,
                                     typ=BaseType('address', is_literal=True),
                                     pos=getpos(self.expr))
        elif len(orignum) == 66:
            return LLLnode.from_list(self.expr.n,
                                     typ=BaseType('bytes32', is_literal=True),
                                     pos=getpos(self.expr))
        else:
            raise InvalidLiteralException(
                "Cannot read 0x value with length %d. Expecting 42 (address incl 0x) or 66 (bytes32 incl 0x)"
                % len(orignum), self.expr)
Esempio n. 27
0
    def parse_return(self):
        """Compile a ``return`` statement into LLL.

        The returned expression is compiled with ``Expr`` and dispatched on
        its type (base type, byte array, list, struct, tuple). Each branch
        checks it against ``self.context.return_type`` and then emits the
        code that places the value in memory and returns it via
        ``make_return_stmt`` / ``gen_tuple_return``.

        Returns:
            An ``LLLnode`` (``typ=None``) implementing the return.

        Raises:
            TypeMismatchException: a value is returned from a function with
                no return type, no value is returned when one is expected,
                or the value's type does not match the declared return type.
            InvalidLiteralException: a literal is out of range for the
                declared return type.
            StructureException: tuple lengths or member kinds do not match.
            Exception: a byte array lives in an unsupported location.
        """
        if self.context.return_type is None:
            if self.stmt.value:
                raise TypeMismatchException("Not expecting to return a value",
                                            self.stmt)
            return LLLnode.from_list(
                make_return_stmt(self.stmt, self.context, 0, 0),
                typ=None,
                pos=getpos(self.stmt),
                valency=0,
            )
        if not self.stmt.value:
            raise TypeMismatchException("Expecting to return a value",
                                        self.stmt)

        sub = Expr(self.stmt.value, self.context).lll_node

        # Returning a value (most common case)
        if isinstance(sub.typ, BaseType):
            sub = unwrap_location(sub)

            if not isinstance(self.context.return_type, BaseType):
                raise TypeMismatchException(
                    f"Return type units mismatch {sub.typ} {self.context.return_type}",
                    self.stmt.value)
            elif self.context.return_type != sub.typ and not sub.typ.is_literal:
                raise TypeMismatchException(
                    f"Trying to return base type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )
            # NOTE(review): `self.context.return_type.typ` appears to be a type
            # name string while `sub.typ` is a type object, so the first
            # comparison below looks like it can never match. In practice only
            # the "int in both names" test selects this branch. It is left
            # unchanged because routing non-integer literals (address, bool,
            # ...) into SizeLimits.in_bounds may not be supported. Confirm.
            elif sub.typ.is_literal and (
                    self.context.return_type.typ == sub.typ
                    or 'int' in self.context.return_type.typ
                    and 'int' in sub.typ.typ):  # noqa: E501
                # Integer literal: range-check it against the declared type.
                if not SizeLimits.in_bounds(self.context.return_type.typ,
                                            sub.value):
                    raise InvalidLiteralException(
                        "Number out of range: " + str(sub.value), self.stmt)
                else:
                    return LLLnode.from_list(
                        [
                            'seq', ['mstore', 0, sub],
                            make_return_stmt(self.stmt, self.context, 0, 32)
                        ],
                        typ=None,
                        pos=getpos(self.stmt),
                        valency=0,
                    )
            elif is_base_type(sub.typ, self.context.return_type.typ) or (
                    is_base_type(sub.typ, 'int128') and is_base_type(
                        self.context.return_type, 'int256')):  # noqa: E501
                # Same base type, or int128 widened into an int256 return:
                # write the single word at memory offset 0 and return 32 bytes.
                return LLLnode.from_list(
                    [
                        'seq', ['mstore', 0, sub],
                        make_return_stmt(self.stmt, self.context, 0, 32)
                    ],
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
            else:
                raise TypeMismatchException(
                    f"Unsupported type conversion: {sub.typ} to {self.context.return_type}",
                    self.stmt.value,
                )
        # Returning a byte array
        elif isinstance(sub.typ, ByteArrayLike):
            if not sub.typ.eq_base(self.context.return_type):
                raise TypeMismatchException(
                    f"Trying to return base type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )
            if sub.typ.maxlen > self.context.return_type.maxlen:
                raise TypeMismatchException(
                    f"Cannot cast from greater max-length {sub.typ.maxlen} to shorter "
                    f"max-length {self.context.return_type.maxlen}",
                    self.stmt.value,
                )

            # loop memory has to be allocated first.
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType('uint256'))
            # len & bytez placeholder have to be declared after each other at all times.
            len_placeholder = self.context.new_placeholder(
                typ=BaseType('uint256'))
            bytez_placeholder = self.context.new_placeholder(typ=sub.typ)

            if sub.location in ('storage', 'memory'):
                # Copy the bytes into memory, zero-pad, store the value 32 in
                # the word just before them, and return from that word onwards,
                # rounded up to a multiple of 32 bytes.
                return LLLnode.from_list([
                    'seq',
                    make_byte_array_copier(LLLnode(
                        bytez_placeholder, location='memory', typ=sub.typ),
                                           sub,
                                           pos=getpos(self.stmt)),
                    zero_pad(bytez_placeholder),
                    ['mstore', len_placeholder, 32],
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        len_placeholder,
                        ['ceil32', ['add', ['mload', bytez_placeholder], 64]],
                        loop_memory_position=loop_memory_position,
                    )
                ],
                                         typ=None,
                                         pos=getpos(self.stmt),
                                         valency=0)
            else:
                raise Exception(f"Invalid location: {sub.location}")

        elif isinstance(sub.typ, ListType):
            # Check before reading `.subtype`: a non-list declared return type
            # would otherwise raise AttributeError instead of a user-facing
            # type error (mirrors the tuple/struct branches).
            if not isinstance(self.context.return_type, ListType):
                raise TypeMismatchException(
                    f"Trying to return list type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )
            # Compare element base type names by stripping any "(...)" or
            # "[...]" suffix from the subtype's string form.
            sub_base_type = re.split(r'\(|\[', str(sub.typ.subtype))[0]
            ret_base_type = re.split(r'\(|\[',
                                     str(self.context.return_type.subtype))[0]
            loop_memory_position = self.context.new_placeholder(
                typ=BaseType('uint256'))
            if sub_base_type != ret_base_type:
                raise TypeMismatchException(
                    f"List return type {sub_base_type} does not match specified "
                    f"return type, expecting {ret_base_type}", self.stmt)
            elif sub.location == "memory" and sub.value != "multi":
                # Already laid out in memory: return it in place.
                return LLLnode.from_list(
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        sub,
                        get_size_of_type(self.context.return_type) * 32,
                        loop_memory_position=loop_memory_position,
                    ),
                    typ=None,
                    pos=getpos(self.stmt),
                    valency=0,
                )
            else:
                # Otherwise (storage, or a "multi" list literal) copy into a
                # fresh memory placeholder of the return type first.
                new_sub = LLLnode.from_list(
                    self.context.new_placeholder(self.context.return_type),
                    typ=self.context.return_type,
                    location='memory',
                )
                setter = make_setter(new_sub,
                                     sub,
                                     'memory',
                                     pos=getpos(self.stmt))
                return LLLnode.from_list([
                    'seq', setter,
                    make_return_stmt(
                        self.stmt,
                        self.context,
                        new_sub,
                        get_size_of_type(self.context.return_type) * 32,
                        loop_memory_position=loop_memory_position,
                    )
                ],
                                         typ=None,
                                         pos=getpos(self.stmt))

        # Returning a struct
        elif isinstance(sub.typ, StructType):
            retty = self.context.return_type
            if not isinstance(retty, StructType) or retty.name != sub.typ.name:
                raise TypeMismatchException(
                    f"Trying to return {sub.typ}, output expecting {self.context.return_type}",
                    self.stmt.value,
                )
            return gen_tuple_return(self.stmt, self.context, sub)

        # Returning a tuple.
        elif isinstance(sub.typ, TupleType):
            if not isinstance(self.context.return_type, TupleType):
                raise TypeMismatchException(
                    f"Trying to return tuple type {sub.typ}, output expecting "
                    f"{self.context.return_type}",
                    self.stmt.value,
                )

            if len(self.context.return_type.members) != len(sub.typ.members):
                raise StructureException("Tuple lengths don't match!",
                                         self.stmt)

            # check return type matches, sub type.
            # Only the Python class of each member type is compared here.
            for i, ret_x in enumerate(self.context.return_type.members):
                s_member = sub.typ.members[i]
                sub_type = s_member if isinstance(s_member,
                                                  NodeType) else s_member.typ
                if type(sub_type) is not type(ret_x):
                    raise StructureException(
                        "Tuple return type does not match annotated return. "
                        f"{type(sub_type)} != {type(ret_x)}", self.stmt)
            return gen_tuple_return(self.stmt, self.context, sub)

        else:
            raise TypeMismatchException(f"Can't return type {sub.typ}",
                                        self.stmt)
Esempio n. 28
0
def to_int256(expr, args, kwargs, context):
    """Build the LLL node that converts ``args[0]`` to an ``int256`` value.

    Supported inputs are numeric literals (``int`` and ``Decimal``, range
    checked at compile time) and LLL values of type ``int128``, ``uint256``,
    ``decimal``, ``bool``, ``bytes32``, ``address``, ``Bytes`` and
    ``String``. ``kwargs`` and ``context`` are accepted but not used here.

    Args:
        expr: AST node of the conversion, used for source positions and
            error reporting.
        args: conversion arguments; only ``args[0]`` is read.

    Returns:
        An ``LLLnode`` typed ``BaseType("int256")``.

    Raises:
        InvalidLiteral: a literal is out of int256 range or of an unknown
            numeric type, or the input type is not supported.
        TypeMismatch: a ``Bytes``/``String`` input has a maxlen above 32.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "num_literal":
        if isinstance(in_arg, int):
            if not SizeLimits.in_bounds("int256", in_arg):
                raise InvalidLiteral(f"Number out of range: {in_arg}", expr)
            return LLLnode.from_list(in_arg,
                                     typ=BaseType("int256"),
                                     pos=getpos(expr))
        elif isinstance(in_arg, Decimal):
            # The emitted value is the truncated one, so that is what gets
            # range-checked.
            truncated = math.trunc(in_arg)
            if not SizeLimits.in_bounds("int256", truncated):
                raise InvalidLiteral(f"Number out of range: {truncated}", expr)
            return LLLnode.from_list(truncated,
                                     typ=BaseType("int256"),
                                     pos=getpos(expr))
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}", expr)

    elif isinstance(in_arg, LLLnode) and input_type == "int128":
        # Every int128 value is a valid int256: relabel the type only.
        return LLLnode.from_list(in_arg,
                                 typ=BaseType("int256"),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == "uint256":
        # Reject values >= 2**255, which would read as negative int256.
        # Both bounds encode 2**255 as a 256-bit word; `shl` is only
        # available from Constantinople on.
        if version_check(begin="constantinople"):
            upper_bound = ["shl", 255, 1]
        else:
            upper_bound = -(2**255)
        return LLLnode.from_list(["uclamplt", in_arg, upper_bound],
                                 typ=BaseType("int256"),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == "decimal":
        # Remove the fixed-point scaling with a signed division.
        return LLLnode.from_list(
            ["sdiv", in_arg, DECIMAL_DIVISOR],
            typ=BaseType("int256"),
            pos=getpos(expr),
        )

    elif isinstance(in_arg, LLLnode) and input_type == "bool":
        return LLLnode.from_list(in_arg,
                                 typ=BaseType("int256"),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type in ("bytes32", "address"):
        # Reinterpret the same word as int256 by rebuilding the node with a
        # new type.
        return LLLnode(value=in_arg.value,
                       args=in_arg.args,
                       typ=BaseType("int256"),
                       pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type in ("Bytes", "String"):
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to int256",
                expr,
            )
        return byte_array_to_num(in_arg, expr, "int256")

    else:
        raise InvalidLiteral(f"Invalid input for int256: {in_arg}", expr)
Esempio n. 29
0
    def parse_BinOp(self):
        """Compile a binary arithmetic expression (``+ - * / % **``) into LLL.

        Both operands are compiled with ``Expr.parse_value_expr``. Except for
        ``**`` (which returns early with its own clamp), the operator body is
        built over the names ``l``/``r``, then range-clamped for the result
        type. The operands are bound to those names via nested ``with``.

        Returns:
            An ``LLLnode`` typed ``BaseType(<operand type>)``, or ``None``
            when the operand types / operator combination is not handled
            here (non-numeric operands, mismatched types, decimal ``**``,
            a literal zero divisor, an unknown operator, ...).
            NOTE(review): the bare ``return`` paths presumably rely on an
            earlier type-checking pass having rejected these expressions;
            confirm.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.right, self.context)

        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            return

        arithmetic_pair = {left.typ.typ, right.typ.typ}
        pos = getpos(self.expr)

        # Special case with uint256 were int literal may be casted.
        # An int128 *literal* that also fits uint256 is retyped as a uint256
        # literal so the two operand types match below.
        if arithmetic_pair == {"uint256", "int128"}:
            # Check right side literal.
            if right.typ.is_literal and SizeLimits.in_bounds("uint256", right.value):
                right = LLLnode.from_list(
                    right.value, typ=BaseType("uint256", None, is_literal=True), pos=pos,
                )

            # Check left side literal.
            elif left.typ.is_literal and SizeLimits.in_bounds("uint256", left.value):
                left = LLLnode.from_list(
                    left.value, typ=BaseType("uint256", None, is_literal=True), pos=pos,
                )

        if left.typ.typ == "decimal" and isinstance(self.expr.op, vy_ast.Pow):
            return

        # Only allow explicit conversions to occur.
        if left.typ.typ != right.typ.typ:
            return

        ltyp, rtyp = left.typ.typ, right.typ.typ
        # `arith` is the operator body written over the names "l" and "r";
        # it stays None for unsupported combinations.
        arith = None
        if isinstance(self.expr.op, (vy_ast.Add, vy_ast.Sub)):
            new_typ = BaseType(ltyp)
            op = "add" if isinstance(self.expr.op, vy_ast.Add) else "sub"

            if ltyp == "uint256" and isinstance(self.expr.op, vy_ast.Add):
                # safeadd
                # Unsigned overflow wraps around, so a sum smaller than `l`
                # means the addition overflowed.
                arith = ["seq", ["assert", ["ge", ["add", "l", "r"], "l"]], ["add", "l", "r"]]

            elif ltyp == "uint256" and isinstance(self.expr.op, vy_ast.Sub):
                # safesub
                # Require l >= r so the unsigned result cannot underflow.
                arith = ["seq", ["assert", ["ge", "l", "r"]], ["sub", "l", "r"]]

            elif ltyp == rtyp:
                # Signed types: range is enforced by the final clamp below.
                arith = [op, "l", "r"]

        elif isinstance(self.expr.op, vy_ast.Mult):
            new_typ = BaseType(ltyp)
            if ltyp == rtyp == "uint256":
                # Overflow check: (l * r) / l must give back r unless l == 0.
                arith = [
                    "with",
                    "ans",
                    ["mul", "l", "r"],
                    [
                        "seq",
                        ["assert", ["or", ["eq", ["div", "ans", "l"], "r"], ["iszero", "l"]]],
                        "ans",
                    ],
                ]

            elif ltyp == rtyp == "int128":
                # TODO should this be 'smul' (note edge cases in YP for smul)
                arith = ["mul", "l", "r"]

            elif ltyp == rtyp == "decimal":
                # TODO should this be smul
                # Same signed overflow check as above; the product carries
                # the fixed-point scale twice, so divide once by
                # DECIMAL_DIVISOR.
                arith = [
                    "with",
                    "ans",
                    ["mul", "l", "r"],
                    [
                        "seq",
                        ["assert", ["or", ["eq", ["sdiv", "ans", "l"], "r"], ["iszero", "l"]]],
                        ["sdiv", "ans", DECIMAL_DIVISOR],
                    ],
                ]

        elif isinstance(self.expr.op, vy_ast.Div):
            if right.typ.is_literal and right.value == 0:
                return

            new_typ = BaseType(ltyp)

            if right.typ.is_literal:
                divisor = "r"
            else:
                # only apply the non-zero clamp when r is not a constant
                divisor = ["clamp_nonzero", "r"]

            if ltyp == rtyp == "uint256":
                arith = ["div", "l", divisor]

            elif ltyp == rtyp == "int128":
                arith = ["sdiv", "l", divisor]

            elif ltyp == rtyp == "decimal":
                # Pre-scale the numerator so the quotient keeps the
                # fixed-point scale.
                arith = [
                    "sdiv",
                    # TODO check overflow cases, also should it be smul
                    ["mul", "l", DECIMAL_DIVISOR],
                    divisor,
                ]

        elif isinstance(self.expr.op, vy_ast.Mod):
            if right.typ.is_literal and right.value == 0:
                return

            new_typ = BaseType(ltyp)

            if right.typ.is_literal:
                divisor = "r"
            else:
                # only apply the non-zero clamp when r is not a constant
                divisor = ["clamp_nonzero", "r"]

            if ltyp == "uint256":
                arith = ["mod", "l", divisor]
            elif ltyp == rtyp:
                # TODO should this be regular mod
                arith = ["smod", "l", divisor]

        elif isinstance(self.expr.op, vy_ast.Pow):
            if ltyp != "int128" and ltyp != "uint256" and isinstance(self.expr.right, vy_ast.Name):
                return
            new_typ = BaseType(ltyp)

            # Constant-base shortcuts: 1 ** x == 1, and 0 ** x is 1 only when
            # x == 0. NOTE(review): the `1 ** x` path does not include `right`
            # in the output, so any side effects in the exponent are dropped;
            # confirm this is intended.
            if self.expr.left.get("value") == 1:
                return LLLnode.from_list([1], typ=new_typ, pos=pos)
            if self.expr.left.get("value") == 0:
                return LLLnode.from_list(["iszero", right], typ=new_typ, pos=pos)

            if ltyp == "int128":
                is_signed = True
                num_bits = 128
            else:
                is_signed = False
                num_bits = 256

            if isinstance(self.expr.left, vy_ast.Int):
                # Literal base: bound the exponent so the result fits.
                value = self.expr.left.value
                upper_bound = calculate_largest_power(value, num_bits, is_signed) + 1
                # for signed integers, this also prevents negative values
                clamp = ["lt", right, upper_bound]
                return LLLnode.from_list(
                    ["seq", ["assert", clamp], ["exp", left, right]], typ=new_typ, pos=pos,
                )
            elif isinstance(self.expr.right, vy_ast.Int):
                # Literal exponent: bound the base (both signs when signed).
                value = self.expr.right.value
                upper_bound = calculate_largest_base(value, num_bits, is_signed) + 1
                if is_signed:
                    clamp = ["and", ["slt", left, upper_bound], ["sgt", left, -upper_bound]]
                else:
                    clamp = ["lt", left, upper_bound]
                return LLLnode.from_list(
                    ["seq", ["assert", clamp], ["exp", left, right]], typ=new_typ, pos=pos,
                )
            else:
                # `a ** b` where neither `a` or `b` are known
                # TODO this is currently unreachable, once we implement a way to do it safely
                # remove the check in `vyper/context/types/value/numeric.py`
                return

        if arith is None:
            return

        # Final range clamp per result type. uint256 needs none here because
        # the uint256 operator bodies above carry their own checks.
        p = ["seq"]
        if new_typ.typ == "int128":
            p.append(int128_clamp(arith))
        elif new_typ.typ == "decimal":
            p.append(
                [
                    "clamp",
                    ["mload", MemoryPositions.MINDECIMAL],
                    arith,
                    ["mload", MemoryPositions.MAXDECIMAL],
                ]
            )
        elif new_typ.typ == "uint256":
            p.append(arith)
        else:
            return

        # Bind the compiled operands to the names used inside `arith`.
        p = ["with", "l", left, ["with", "r", right, p]]
        return LLLnode.from_list(p, typ=new_typ, pos=pos)
Esempio n. 30
0
def _convert_decimal_to_int(val, o_typ):
    """Convert the decimal ``val`` to integer type ``o_typ``.

    The range check runs on the original, untruncated value. A decimal just
    past the type's limit is therefore rejected even if rounding toward zero
    would bring it back into range.

    Returns:
        ``round_towards_zero(val)`` when ``val`` is within bounds for
        ``o_typ``, otherwise ``None``.
    """
    if SizeLimits.in_bounds(o_typ, val):
        return round_towards_zero(val)
    return None