Пример #1
0
    def arithmetic(self):
        """Build the LLL node for the binary arithmetic expression ``self.expr``.

        Both operands are parsed with ``Expr.parse_value_expr`` and must be
        numeric.  Processing then happens in this order:

        1. If both operands are integer literals the operation is folded at
           compile time and the folded literal is parsed instead.
        2. If the operand types are exactly ``{'uint256', 'int128'}`` and one
           side is a literal that fits in ``uint256``, that literal is
           re-typed as ``uint256``.
        3. Any remaining type mismatch is rejected; no implicit conversion.
        4. The operator-specific LLL is built, including unit / positional
           bookkeeping and the ``uint256`` overflow assertions.
        5. ``int128`` and ``decimal`` results are wrapped in a ``clamp``
           between the bounds loaded from ``MemoryPositions``; ``uint256``
           results are returned as built.

        Raises:
            TypeMismatchException: non-numeric operands, mismatched operand
                types, unit or positional misuse, or an unsupported
                exponentiation.
            ParserException: unsupported operator between two literals, or a
                literal division / modulo by zero.
            Exception: unsupported operator or operand-type combination.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.right, self.context)
        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            raise TypeMismatchException(
                "Unsupported types for arithmetic op: %r %r" %
                (left.typ, right.typ), self.expr)

        arithmetic_pair = {left.typ.typ, right.typ.typ}

        # Special Case: Simplify any literal to literal arithmetic at compile time.
        if left.typ.is_literal and right.typ.is_literal and \
           isinstance(right.value, int) and isinstance(left.value, int):

            if isinstance(self.expr.op, ast.Add):
                val = left.value + right.value
            elif isinstance(self.expr.op, ast.Sub):
                val = left.value - right.value
            elif isinstance(self.expr.op, ast.Mult):
                val = left.value * right.value
            elif isinstance(self.expr.op, (ast.Div, ast.Mod)):
                if right.value == 0:
                    # Report through the compiler's own exception (with the
                    # offending node) instead of a bare ZeroDivisionError.
                    raise ParserException(
                        'Cannot divide or take modulus by literal zero',
                        self.expr)
                # Fold with the semantics of the sdiv / smod emitted below for
                # non-literal operands: the quotient truncates toward zero and
                # the remainder takes the sign of the dividend.  Python's own
                # // and % floor and follow the divisor's sign instead.
                quotient, remainder = divmod(abs(left.value), abs(right.value))
                if isinstance(self.expr.op, ast.Div):
                    signs_differ = (left.value < 0) != (right.value < 0)
                    val = -quotient if signs_differ else quotient
                else:
                    val = -remainder if left.value < 0 else remainder
            elif isinstance(self.expr.op, ast.Pow):
                # NOTE(review): a negative literal exponent makes Python
                # return a float here -- confirm whether it should be rejected.
                val = left.value**right.value
            else:
                raise ParserException(
                    'Unsupported literal operator: %s' %
                    str(type(self.expr.op)), self.expr)

            # Re-parse the folded value as a literal located at the original
            # expression.
            num = ast.Num(val)
            num.source_code = self.expr.source_code
            num.lineno = self.expr.lineno
            num.col_offset = self.expr.col_offset

            return Expr.parse_value_expr(num, self.context)

        # Special case with uint256 where an int literal may be cast.
        if arithmetic_pair == {'uint256', 'int128'}:
            # Check right side literal.
            if right.typ.is_literal and SizeLimits.in_bounds(
                    'uint256', right.value):
                right = LLLnode.from_list(right.value,
                                          typ=BaseType('uint256',
                                                       None,
                                                       is_literal=True),
                                          pos=getpos(self.expr))
            # Check left side literal.
            elif left.typ.is_literal and SizeLimits.in_bounds(
                    'uint256', left.value):
                left = LLLnode.from_list(left.value,
                                         typ=BaseType('uint256',
                                                      None,
                                                      is_literal=True),
                                         pos=getpos(self.expr))

        # Only allow explicit conversions to occur.
        if left.typ.typ != right.typ.typ:
            raise TypeMismatchException(
                "Cannot implicitly convert {} to {}.".format(
                    left.typ.typ, right.typ.typ), self.expr)

        ltyp, rtyp = left.typ.typ, right.typ.typ
        if isinstance(self.expr.op, (ast.Add, ast.Sub)):
            # Units must agree unless one side carries no unit at all.
            if (left.typ.unit != right.typ.unit
                    and left.typ.unit is not None
                    and right.typ.unit is not None):
                raise TypeMismatchException(
                    "Unit mismatch: %r %r" % (left.typ.unit, right.typ.unit),
                    self.expr)
            if left.typ.positional and right.typ.positional and isinstance(
                    self.expr.op, ast.Add):
                raise TypeMismatchException("Cannot add two positional units!",
                                            self.expr)
            new_unit = left.typ.unit or right.typ.unit
            # xor, as subtracting two positionals gives a delta
            new_positional = left.typ.positional ^ right.typ.positional
            op = 'add' if isinstance(self.expr.op, ast.Add) else 'sub'
            # NOTE(review): the uint256 branches embed `left` / `right` more
            # than once in the emitted LLL -- confirm that operands with side
            # effects are not evaluated repeatedly.
            if ltyp == 'uint256' and isinstance(self.expr.op, ast.Add):
                o = LLLnode.from_list(
                    [
                        'seq',
                        # Checks that: a + b >= a
                        ['assert', ['ge', ['add', left, right], left]],
                        ['add', left, right]
                    ],
                    typ=BaseType('uint256', new_unit, new_positional),
                    pos=getpos(self.expr))
            elif ltyp == 'uint256' and isinstance(self.expr.op, ast.Sub):
                o = LLLnode.from_list(
                    [
                        'seq',
                        # Checks that: a >= b
                        ['assert', ['ge', left, right]],
                        ['sub', left, right]
                    ],
                    typ=BaseType('uint256', new_unit, new_positional),
                    pos=getpos(self.expr))
            elif ltyp == rtyp:
                o = LLLnode.from_list([op, left, right],
                                      typ=BaseType(ltyp, new_unit,
                                                   new_positional),
                                      pos=getpos(self.expr))
            else:
                raise Exception("Unsupported Operation '%r(%r, %r)'" %
                                (op, ltyp, rtyp))
        elif isinstance(self.expr.op, ast.Mult):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException(
                    "Cannot multiply positional values!", self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit)
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list(
                    [
                        'seq',
                        # Checks that: a == 0 || (a * b) / a == b
                        [
                            'assert',
                            [
                                'or', ['iszero', left],
                                [
                                    'eq', ['div', ['mul', left, right], left],
                                    right
                                ]
                            ]
                        ],
                        ['mul', left, right]
                    ],
                    typ=BaseType('uint256', new_unit),
                    pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                o = LLLnode.from_list(['mul', left, right],
                                      typ=BaseType('int128', new_unit),
                                      pos=getpos(self.expr))
            elif ltyp == rtyp == 'decimal':
                # Checks that: ans / l == r || l == 0, then rescales the raw
                # product by DECIMAL_DIVISOR.
                o = LLLnode.from_list([
                    'with', 'r', right,
                    [
                        'with', 'l', left,
                        [
                            'with', 'ans', ['mul', 'l', 'r'],
                            [
                                'seq',
                                [
                                    'assert',
                                    [
                                        'or', [
                                            'eq', ['sdiv', 'ans', 'l'], 'r'
                                        ], ['iszero', 'l']
                                    ]
                                ], ['sdiv', 'ans', DECIMAL_DIVISOR]
                            ]
                        ]
                    ]
                ],
                                      typ=BaseType('decimal', new_unit),
                                      pos=getpos(self.expr))
            else:
                raise Exception("Unsupported Operation 'mul(%r, %r)'" %
                                (ltyp, rtyp))
        elif isinstance(self.expr.op, ast.Div):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot divide positional values!",
                                            self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit, div=True)
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list(
                    [
                        'seq',
                        # Checks that:  b != 0
                        ['assert', right],
                        ['div', left, right]
                    ],
                    typ=BaseType('uint256', new_unit),
                    pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                o = LLLnode.from_list(['sdiv', left, ['clamp_nonzero', right]],
                                      typ=BaseType('int128', new_unit),
                                      pos=getpos(self.expr))
            elif ltyp == rtyp == 'decimal':
                # The dividend is scaled by DECIMAL_DIVISOR before dividing.
                o = LLLnode.from_list([
                    'with', 'l', left,
                    [
                        'with', 'r', ['clamp_nonzero', right],
                        ['sdiv', ['mul', 'l', DECIMAL_DIVISOR], 'r']
                    ]
                ],
                                      typ=BaseType('decimal', new_unit),
                                      pos=getpos(self.expr))
            else:
                raise Exception("Unsupported Operation 'div(%r, %r)'" %
                                (ltyp, rtyp))
        elif isinstance(self.expr.op, ast.Mod):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException(
                    "Cannot use positional values as modulus arguments!",
                    self.expr)
            # Units must agree unless one side carries no unit at all.
            if (left.typ.unit != right.typ.unit
                    and left.typ.unit is not None
                    and right.typ.unit is not None):
                raise TypeMismatchException(
                    "Modulus arguments must have same unit", self.expr)
            new_unit = left.typ.unit or right.typ.unit
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list(
                    # Checks that: b != 0
                    ['seq', ['assert', right], ['mod', left, right]],
                    typ=BaseType('uint256', new_unit),
                    pos=getpos(self.expr))
            elif ltyp == rtyp:
                o = LLLnode.from_list(['smod', left, ['clamp_nonzero', right]],
                                      typ=BaseType(ltyp, new_unit),
                                      pos=getpos(self.expr))
            else:
                raise Exception("Unsupported Operation 'mod(%r, %r)'" %
                                (ltyp, rtyp))
        elif isinstance(self.expr.op, ast.Pow):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException(
                    "Cannot use positional values as exponential arguments!",
                    self.expr)
            if right.typ.unit:
                raise TypeMismatchException(
                    "Cannot use unit values as exponents", self.expr)
            if ltyp != 'int128' and ltyp != 'uint256' and isinstance(
                    self.expr.right, ast.Name):
                raise TypeMismatchException(
                    "Cannot use dynamic values as exponents, for unit base types",
                    self.expr)
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list([
                    'seq',
                    # Checks that: b == 1 || b == 0 || a < a ** b
                    [
                        'assert',
                        [
                            'or', ['or', ['eq', right, 1], ['iszero', right]],
                            ['lt', left, ['exp', left, right]]
                        ]
                    ], ['exp', left, right]
                ],
                                      typ=BaseType('uint256'),
                                      pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                new_unit = left.typ.unit
                if left.typ.unit and not isinstance(self.expr.right, ast.Name):
                    # With a non-name exponent, one key of the base's unit
                    # dict is mapped to the exponent's `.n` value.
                    new_unit = {
                        left.typ.unit.copy().popitem()[0]: self.expr.right.n
                    }
                o = LLLnode.from_list(['exp', left, right],
                                      typ=BaseType('int128', new_unit),
                                      pos=getpos(self.expr))
            else:
                raise TypeMismatchException(
                    'Only whole number exponents are supported', self.expr)
        else:
            raise Exception("Unsupported binop: %r" % self.expr.op)

        # Bound signed results; uint256 was already guarded by the assertions
        # built into its branches above.
        if o.typ.typ == 'int128':
            return LLLnode.from_list([
                'clamp', ['mload', MemoryPositions.MINNUM], o,
                ['mload', MemoryPositions.MAXNUM]
            ],
                                     typ=o.typ,
                                     pos=getpos(self.expr))
        elif o.typ.typ == 'decimal':
            return LLLnode.from_list([
                'clamp', ['mload', MemoryPositions.MINDECIMAL], o,
                ['mload', MemoryPositions.MAXDECIMAL]
            ],
                                     typ=o.typ,
                                     pos=getpos(self.expr))
        elif o.typ.typ == 'uint256':
            return o
        else:
            raise Exception("%r %r" % (o, o.typ))
Пример #2
0
 def arithmetic(self):
     """Build the LLL node for the binary arithmetic expression ``self.expr``.

     Both operands are parsed with ``Expr.parse_value_expr`` and must be
     numeric.  This version accepts mixed ``int128`` / ``decimal`` operands
     for add, sub, mul, div and mod: the result is typed ``decimal`` and
     factors of ``DECIMAL_DIVISOR`` are applied where the branches below
     show.  Dividing two ``int128`` values also yields a ``decimal``.
     Exponentiation is only supported for ``int128`` operands.

     The result is always wrapped in a ``clamp`` between the bounds loaded
     from ``MemoryPositions`` for its type.

     Raises:
         TypeMismatchException: non-numeric operands, unit or positional
             misuse, or an unsupported exponentiation.
         Exception: unsupported operator or operand-type combination.
     """
     left = Expr.parse_value_expr(self.expr.left, self.context)
     right = Expr.parse_value_expr(self.expr.right, self.context)
     if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
         raise TypeMismatchException(
             "Unsupported types for arithmetic op: %r %r" %
             (left.typ, right.typ), self.expr)
     ltyp, rtyp = left.typ.typ, right.typ.typ
     if isinstance(self.expr.op, (ast.Add, ast.Sub)):
         # Units must agree unless one side carries no unit at all.
         if (left.typ.unit != right.typ.unit
                 and left.typ.unit is not None
                 and right.typ.unit is not None):
             raise TypeMismatchException(
                 "Unit mismatch: %r %r" % (left.typ.unit, right.typ.unit),
                 self.expr)
         if left.typ.positional and right.typ.positional and isinstance(
                 self.expr.op, ast.Add):
             raise TypeMismatchException("Cannot add two positional units!",
                                         self.expr)
         new_unit = left.typ.unit or right.typ.unit
         # xor, as subtracting two positionals gives a delta
         new_positional = left.typ.positional ^ right.typ.positional
         op = 'add' if isinstance(self.expr.op, ast.Add) else 'sub'
         if ltyp == rtyp:
             o = LLLnode.from_list([op, left, right],
                                   typ=BaseType(ltyp, new_unit,
                                                new_positional),
                                   pos=getpos(self.expr))
         elif ltyp == 'int128' and rtyp == 'decimal':
             # Scale the int128 side by DECIMAL_DIVISOR before combining.
             o = LLLnode.from_list(
                 [op, ['mul', left, DECIMAL_DIVISOR], right],
                 typ=BaseType('decimal', new_unit, new_positional),
                 pos=getpos(self.expr))
         elif ltyp == 'decimal' and rtyp == 'int128':
             o = LLLnode.from_list(
                 [op, left, ['mul', right, DECIMAL_DIVISOR]],
                 typ=BaseType('decimal', new_unit, new_positional),
                 pos=getpos(self.expr))
         else:
             raise Exception("Unsupported Operation '%r(%r, %r)'" %
                             (op, ltyp, rtyp))
     elif isinstance(self.expr.op, ast.Mult):
         if left.typ.positional or right.typ.positional:
             raise TypeMismatchException(
                 "Cannot multiply positional values!", self.expr)
         new_unit = combine_units(left.typ.unit, right.typ.unit)
         if ltyp == rtyp == 'int128':
             o = LLLnode.from_list(['mul', left, right],
                                   typ=BaseType('int128', new_unit),
                                   pos=getpos(self.expr))
         elif ltyp == rtyp == 'decimal':
             # Checks that: ans / l == r || l == 0, then rescales the raw
             # product by DECIMAL_DIVISOR.  The zero test must be `iszero`:
             # `not` is a bitwise complement, which is non-zero for every
             # l except -1 and would make the assertion vacuous.
             o = LLLnode.from_list([
                 'with', 'r', right,
                 [
                     'with', 'l', left,
                     [
                         'with', 'ans', ['mul', 'l', 'r'],
                         [
                             'seq',
                             [
                                 'assert',
                                 [
                                     'or', [
                                         'eq', ['sdiv', 'ans', 'l'], 'r'
                                     ], ['iszero', 'l']
                                 ]
                             ], ['sdiv', 'ans', DECIMAL_DIVISOR]
                         ]
                     ]
                 ]
             ],
                                   typ=BaseType('decimal', new_unit),
                                   pos=getpos(self.expr))
         elif (ltyp == 'int128'
               and rtyp == 'decimal') or (ltyp == 'decimal'
                                          and rtyp == 'int128'):
             # Same overflow check as above; the raw product is returned
             # without rescaling.
             o = LLLnode.from_list([
                 'with', 'r', right,
                 [
                     'with', 'l', left,
                     [
                         'with', 'ans', ['mul', 'l', 'r'],
                         [
                             'seq',
                             [
                                 'assert',
                                 [
                                     'or', [
                                         'eq', ['sdiv', 'ans', 'l'], 'r'
                                     ], ['iszero', 'l']
                                 ]
                             ], 'ans'
                         ]
                     ]
                 ]
             ],
                                   typ=BaseType('decimal', new_unit),
                                   pos=getpos(self.expr))
         else:
             raise Exception("Unsupported Operation 'mul(%r, %r)'" %
                             (ltyp, rtyp))
     elif isinstance(self.expr.op, ast.Div):
         if left.typ.positional or right.typ.positional:
             raise TypeMismatchException("Cannot divide positional values!",
                                         self.expr)
         new_unit = combine_units(left.typ.unit, right.typ.unit, div=True)
         if ltyp == rtyp == 'int128':
             # int128 / int128 is typed decimal: the dividend is scaled by
             # DECIMAL_DIVISOR before dividing.
             o = LLLnode.from_list([
                 'sdiv', ['mul', left, DECIMAL_DIVISOR],
                 ['clamp_nonzero', right]
             ],
                                   typ=BaseType('decimal', new_unit),
                                   pos=getpos(self.expr))
         elif ltyp == rtyp == 'decimal':
             o = LLLnode.from_list([
                 'with', 'l', left,
                 [
                     'with', 'r', ['clamp_nonzero', right],
                     ['sdiv', ['mul', 'l', DECIMAL_DIVISOR], 'r']
                 ]
             ],
                                   typ=BaseType('decimal', new_unit),
                                   pos=getpos(self.expr))
         elif ltyp == 'int128' and rtyp == 'decimal':
             # The int128 dividend is scaled by DECIMAL_DIVISOR squared.
             o = LLLnode.from_list([
                 'sdiv', ['mul', left, DECIMAL_DIVISOR**2],
                 ['clamp_nonzero', right]
             ],
                                   typ=BaseType('decimal', new_unit),
                                   pos=getpos(self.expr))
         elif ltyp == 'decimal' and rtyp == 'int128':
             o = LLLnode.from_list(['sdiv', left, ['clamp_nonzero', right]],
                                   typ=BaseType('decimal', new_unit),
                                   pos=getpos(self.expr))
         else:
             raise Exception("Unsupported Operation 'div(%r, %r)'" %
                             (ltyp, rtyp))
     elif isinstance(self.expr.op, ast.Mod):
         if left.typ.positional or right.typ.positional:
             raise TypeMismatchException(
                 "Cannot use positional values as modulus arguments!",
                 self.expr)
         # Units must agree unless one side carries no unit at all.
         if (left.typ.unit != right.typ.unit
                 and left.typ.unit is not None
                 and right.typ.unit is not None):
             raise TypeMismatchException(
                 "Modulus arguments must have same unit", self.expr)
         new_unit = left.typ.unit or right.typ.unit
         if ltyp == rtyp:
             o = LLLnode.from_list(['smod', left, ['clamp_nonzero', right]],
                                   typ=BaseType(ltyp, new_unit),
                                   pos=getpos(self.expr))
         elif ltyp == 'decimal' and rtyp == 'int128':
             # The int128 modulus is checked for zero, then scaled.
             o = LLLnode.from_list([
                 'smod', left,
                 ['mul', ['clamp_nonzero', right], DECIMAL_DIVISOR]
             ],
                                   typ=BaseType('decimal', new_unit),
                                   pos=getpos(self.expr))
         elif ltyp == 'int128' and rtyp == 'decimal':
             o = LLLnode.from_list([
                 'smod', ['mul', left, DECIMAL_DIVISOR],
                 ['clamp_nonzero', right]
             ],
                                   typ=BaseType('decimal', new_unit),
                                   pos=getpos(self.expr))
         else:
             raise Exception("Unsupported Operation 'mod(%r, %r)'" %
                             (ltyp, rtyp))
     elif isinstance(self.expr.op, ast.Pow):
         if left.typ.positional or right.typ.positional:
             raise TypeMismatchException(
                 "Cannot use positional values as exponential arguments!",
                 self.expr)
         if right.typ.unit:
             raise TypeMismatchException(
                 "Cannot use unit values as exponents", self.expr)
         if ltyp != 'int128' and isinstance(self.expr.right, ast.Name):
             raise TypeMismatchException(
                 "Cannot use dynamic values as exponents, for unit base types",
                 self.expr)
         if ltyp == rtyp == 'int128':
             new_unit = left.typ.unit
             if left.typ.unit and not isinstance(self.expr.right, ast.Name):
                 # With a non-name exponent, one key of the base's unit
                 # dict is mapped to the exponent's `.n` value.
                 new_unit = {
                     left.typ.unit.copy().popitem()[0]: self.expr.right.n
                 }
             o = LLLnode.from_list(['exp', left, right],
                                   typ=BaseType('int128', new_unit),
                                   pos=getpos(self.expr))
         else:
             raise TypeMismatchException(
                 'Only whole number exponents are supported', self.expr)
     else:
         raise Exception("Unsupported binop: %r" % self.expr.op)

     # Bound the result to the limits loaded for its type.
     if o.typ.typ == 'int128':
         return LLLnode.from_list([
             'clamp', ['mload', MemoryPositions.MINNUM], o,
             ['mload', MemoryPositions.MAXNUM]
         ],
                                  typ=o.typ,
                                  pos=getpos(self.expr))
     elif o.typ.typ == 'decimal':
         return LLLnode.from_list([
             'clamp', ['mload', MemoryPositions.MINDECIMAL], o,
             ['mload', MemoryPositions.MAXDECIMAL]
         ],
                                  typ=o.typ,
                                  pos=getpos(self.expr))
     else:
         raise Exception("%r %r" % (o, o.typ))
Пример #3
0
    def arithmetic(self):
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.right, self.context)

        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            raise TypeMismatchException(
                f"Unsupported types for arithmetic op: {left.typ} {right.typ}",
                self.expr,
            )

        arithmetic_pair = {left.typ.typ, right.typ.typ}

        # Special Case: Simplify any literal to literal arithmetic at compile time.
        if left.typ.is_literal and right.typ.is_literal and \
           isinstance(right.value, int) and isinstance(left.value, int) and \
           arithmetic_pair.issubset({'uint256', 'int128'}):

            if isinstance(self.expr.op, vy_ast.Add):
                val = left.value + right.value
            elif isinstance(self.expr.op, vy_ast.Sub):
                val = left.value - right.value
            elif isinstance(self.expr.op, vy_ast.Mult):
                val = left.value * right.value
            elif isinstance(self.expr.op, vy_ast.Pow):
                val = left.value ** right.value
            elif isinstance(self.expr.op, (vy_ast.Div, vy_ast.Mod)):
                if right.value == 0:
                    raise ZeroDivisionException(
                        "integer division or modulo by zero",
                        self.expr,
                    )
                if isinstance(self.expr.op, vy_ast.Div):
                    val = left.value // right.value
                elif isinstance(self.expr.op, vy_ast.Mod):
                    # modified modulo logic to remain consistent with EVM behaviour
                    val = abs(left.value) % abs(right.value)
                    if left.value < 0:
                        val = -val
            else:
                raise ParserException(
                    f'Unsupported literal operator: {type(self.expr.op)}',
                    self.expr,
                )

            num = vy_ast.Int(n=val)
            num.full_source_code = self.expr.full_source_code
            num.node_source_code = self.expr.node_source_code
            num.lineno = self.expr.lineno
            num.col_offset = self.expr.col_offset
            num.end_lineno = self.expr.end_lineno
            num.end_col_offset = self.expr.end_col_offset

            return Expr.parse_value_expr(num, self.context)

        pos = getpos(self.expr)

        # Special case with uint256 were int literal may be casted.
        if arithmetic_pair == {'uint256', 'int128'}:
            # Check right side literal.
            if right.typ.is_literal and SizeLimits.in_bounds('uint256', right.value):
                right = LLLnode.from_list(
                    right.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=pos,
                )

            # Check left side literal.
            elif left.typ.is_literal and SizeLimits.in_bounds('uint256', left.value):
                left = LLLnode.from_list(
                    left.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=pos,
                )

        if left.typ.typ == "decimal" and isinstance(self.expr.op, vy_ast.Pow):
            raise TypeMismatchException(
                "Cannot perform exponentiation on decimal values.",
                self.expr,
            )

        # Only allow explicit conversions to occur.
        if left.typ.typ != right.typ.typ:
            raise TypeMismatchException(
                f"Cannot implicitly convert {left.typ.typ} to {right.typ.typ}.",
                self.expr,
            )

        ltyp, rtyp = left.typ.typ, right.typ.typ
        if isinstance(self.expr.op, (vy_ast.Add, vy_ast.Sub)):
            if left.typ.unit != right.typ.unit and left.typ.unit != {} and right.typ.unit != {}:
                raise TypeMismatchException(
                    f"Unit mismatch: {left.typ.unit} {right.typ.unit}",
                    self.expr,
                )
            if (
                left.typ.positional and
                right.typ.positional and
                isinstance(self.expr.op, vy_ast.Add)
            ):
                raise TypeMismatchException(
                    "Cannot add two positional units!",
                    self.expr,
                )
            new_unit = left.typ.unit or right.typ.unit

            # xor, as subtracting two positionals gives a delta
            new_positional = left.typ.positional ^ right.typ.positional

            new_typ = BaseType(ltyp, new_unit, new_positional)

            op = 'add' if isinstance(self.expr.op, vy_ast.Add) else 'sub'

            if ltyp == 'uint256' and isinstance(self.expr.op, vy_ast.Add):
                # safeadd
                arith = ['seq',
                         ['assert', ['ge', ['add', 'l', 'r'], 'l']],
                         ['add', 'l', 'r']]

            elif ltyp == 'uint256' and isinstance(self.expr.op, vy_ast.Sub):
                # safesub
                arith = ['seq',
                         ['assert', ['ge', 'l', 'r']],
                         ['sub', 'l', 'r']]

            elif ltyp == rtyp:
                arith = [op, 'l', 'r']

            else:
                raise Exception(f"Unsupported Operation '{op}({ltyp}, {rtyp})'")

        elif isinstance(self.expr.op, vy_ast.Mult):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot multiply positional values!", self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit)
            new_typ = BaseType(ltyp, new_unit)

            if ltyp == rtyp == 'uint256':
                arith = ['with', 'ans', ['mul', 'l', 'r'],
                         ['seq',
                             ['assert',
                                 ['or',
                                     ['eq', ['div', 'ans', 'l'], 'r'],
                                     ['iszero', 'l']]],
                             'ans']]

            elif ltyp == rtyp == 'int128':
                # TODO should this be 'smul' (note edge cases in YP for smul)
                arith = ['mul', 'l', 'r']

            elif ltyp == rtyp == 'decimal':
                # TODO should this be smul
                arith = ['with', 'ans', ['mul', 'l', 'r'],
                         ['seq',
                             ['assert',
                                 ['or',
                                     ['eq', ['sdiv', 'ans', 'l'], 'r'],
                                     ['iszero', 'l']]],
                             ['sdiv', 'ans', DECIMAL_DIVISOR]]]
            else:
                raise Exception(f"Unsupported Operation 'mul({ltyp}, {rtyp})'")

        elif isinstance(self.expr.op, vy_ast.Div):
            if right.typ.is_literal and right.value == 0:
                raise ZeroDivisionException("Cannot divide by 0.", self.expr)
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot divide positional values!", self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit, div=True)
            new_typ = BaseType(ltyp, new_unit)
            if ltyp == rtyp == 'uint256':
                arith = ['div', 'l', ['clamp_nonzero', 'r']]

            elif ltyp == rtyp == 'int128':
                arith = ['sdiv', 'l', ['clamp_nonzero', 'r']]

            elif ltyp == rtyp == 'decimal':
                arith = ['sdiv',
                         # TODO check overflow cases, also should it be smul
                         ['mul', 'l', DECIMAL_DIVISOR],
                         ['clamp_nonzero', 'r']]

            else:
                raise Exception(f"Unsupported Operation 'div({ltyp}, {rtyp})'")

        elif isinstance(self.expr.op, vy_ast.Mod):
            if right.typ.is_literal and right.value == 0:
                raise ZeroDivisionException("Cannot calculate modulus of 0.", self.expr)
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException(
                    "Cannot use positional values as modulus arguments!",
                    self.expr,
                )

            if not are_units_compatible(left.typ, right.typ) and not (left.typ.unit or right.typ.unit):  # noqa: E501
                raise TypeMismatchException("Modulus arguments must have same unit", self.expr)
            new_unit = left.typ.unit or right.typ.unit
            new_typ = BaseType(ltyp, new_unit)

            if ltyp == rtyp == 'uint256':
                arith = ['mod', 'l', ['clamp_nonzero', 'r']]
            elif ltyp == rtyp:
                # TODO should this be regular mod
                arith = ['smod', 'l', ['clamp_nonzero', 'r']]

            else:
                raise Exception(f"Unsupported Operation 'mod({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, vy_ast.Pow):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException(
                    "Cannot use positional values as exponential arguments!",
                    self.expr,
                )
            if right.typ.unit:
                raise TypeMismatchException(
                    "Cannot use unit values as exponents",
                    self.expr,
                )
            if ltyp != 'int128' and ltyp != 'uint256' and isinstance(self.expr.right, vy_ast.Name):
                raise TypeMismatchException(
                    "Cannot use dynamic values as exponents, for unit base types",
                    self.expr,
                )
            new_unit = left.typ.unit
            if left.typ.unit and not isinstance(self.expr.right, vy_ast.Name):
                new_unit = {left.typ.unit.copy().popitem()[0]: self.expr.right.n}
            new_typ = BaseType(ltyp, new_unit)

            if ltyp == rtyp == 'uint256':
                arith = ['seq',
                         ['assert',
                             ['or',
                                 # r == 1 | iszero(r)
                                 # could be simplified to ~(r & 1)
                                 ['or', ['eq', 'r', 1], ['iszero', 'r']],
                                 ['lt', 'l', ['exp', 'l', 'r']]]],
                         ['exp', 'l', 'r']]
            elif ltyp == rtyp == 'int128':
                arith = ['exp', 'l', 'r']

            else:
                raise TypeMismatchException('Only whole number exponents are supported', self.expr)
        else:
            raise ParserException(f"Unsupported binary operator: {self.expr.op}", self.expr)

        p = ['seq']

        if new_typ.typ == 'int128':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINNUM],
                arith,
                ['mload', MemoryPositions.MAXNUM],
            ])
        elif new_typ.typ == 'decimal':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINDECIMAL],
                arith,
                ['mload', MemoryPositions.MAXDECIMAL],
            ])
        elif new_typ.typ == 'uint256':
            p.append(arith)
        else:
            raise Exception(f"{arith} {new_typ}")

        p = ['with', 'l', left, ['with', 'r', right, p]]
        return LLLnode.from_list(p, typ=new_typ, pos=pos)
Пример #4
0
    def arithmetic(self):
        """
        Build the LLL for the binary arithmetic expression held in `self.expr`.

        Both operands are resolved through `arithmetic_get_reference`, which
        returns a `(pre_alloc, node)` pair; any truthy `pre_alloc` node is
        emitted ahead of the arithmetic itself.

        Two integer literals typed `uint256`/`int128` are folded at compile
        time and re-parsed as a single number node. Otherwise, when one
        operand is `uint256` and the other `int128`, a literal that fits in
        `uint256` is retyped to `uint256`; after that both operands must have
        the same base type.

        Returns:
            An `LLLnode` typed like the operation result. `int128` results are
            wrapped in a clamp between MINNUM and MAXNUM, `decimal` results
            between MINDECIMAL and MAXDECIMAL, `uint256` results are emitted
            as built by the operator branch.

        Raises:
            TypeMismatchException: non-numeric operands, differing base types,
                unit mismatches, or positional/unit values used where the
                operator does not allow them.
            ParserException: unsupported operator, or a literal-only
                expression that divides by zero, takes a modulus of zero, or
                uses a negative exponent.
            Exception: the operand base type is not handled by the operator.
        """
        pre_alloc_left, left = self.arithmetic_get_reference(self.expr.left)
        pre_alloc_right, right = self.arithmetic_get_reference(self.expr.right)

        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            raise TypeMismatchException(
                f"Unsupported types for arithmetic op: {left.typ} {right.typ}",
                self.expr,
            )

        arithmetic_pair = {left.typ.typ, right.typ.typ}

        # Special Case: Simplify any literal to literal arithmetic at compile time.
        if left.typ.is_literal and right.typ.is_literal and \
           isinstance(right.value, int) and isinstance(left.value, int) and \
           arithmetic_pair.issubset({'uint256', 'int128'}):

            # Python's `//` and `%` raise a bare ZeroDivisionError on a zero
            # divisor; report it as a compiler error tied to the expression
            # instead of letting an internal traceback escape.
            if right.value == 0:
                if isinstance(self.expr.op, ast.Div):
                    raise ParserException("Cannot divide by 0.", self.expr)
                if isinstance(self.expr.op, ast.Mod):
                    raise ParserException("Cannot calculate modulus of 0.", self.expr)

            # Python's `**` returns a float for a negative exponent (and raises
            # ZeroDivisionError for a zero base), so the folded value would no
            # longer be an integer literal.
            if isinstance(self.expr.op, ast.Pow) and right.value < 0:
                raise ParserException("Cannot use negative values as exponents", self.expr)

            if isinstance(self.expr.op, ast.Add):
                val = left.value + right.value
            elif isinstance(self.expr.op, ast.Sub):
                val = left.value - right.value
            elif isinstance(self.expr.op, ast.Mult):
                val = left.value * right.value
            elif isinstance(self.expr.op, ast.Div):
                val = left.value // right.value
            elif isinstance(self.expr.op, ast.Mod):
                val = left.value % right.value
            elif isinstance(self.expr.op, ast.Pow):
                val = left.value ** right.value
            else:
                raise ParserException(
                    f'Unsupported literal operator: {str(type(self.expr.op))}',
                    self.expr,
                )

            # Replace the whole expression with one number node that keeps the
            # original expression's source location.
            num = ast.Num(n=val)
            num.source_code = self.expr.source_code
            num.lineno = self.expr.lineno
            num.col_offset = self.expr.col_offset
            num.end_lineno = self.expr.end_lineno
            num.end_col_offset = self.expr.end_col_offset

            return Expr.parse_value_expr(num, self.context)

        # Special case with uint256 where an int literal may be cast.
        if arithmetic_pair == {'uint256', 'int128'}:
            # Check right side literal.
            if right.typ.is_literal and SizeLimits.in_bounds('uint256', right.value):
                right = LLLnode.from_list(
                    right.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=getpos(self.expr),
                )

            # Check left side literal.
            elif left.typ.is_literal and SizeLimits.in_bounds('uint256', left.value):
                left = LLLnode.from_list(
                    left.value,
                    typ=BaseType('uint256', None, is_literal=True),
                    pos=getpos(self.expr),
                )

        # Only allow explicit conversions to occur.
        if left.typ.typ != right.typ.typ:
            raise TypeMismatchException(
                f"Cannot implicitly convert {left.typ.typ} to {right.typ.typ}.",
                self.expr,
            )

        # NOTE(review): several trees below embed `left`/`right` more than once
        # (once in the check, once in the result) - confirm the operands are
        # always safe to evaluate repeatedly.
        ltyp, rtyp = left.typ.typ, right.typ.typ
        if isinstance(self.expr.op, (ast.Add, ast.Sub)):
            # Units must match unless one side is unitless.
            if left.typ.unit != right.typ.unit and left.typ.unit != {} and right.typ.unit != {}:
                raise TypeMismatchException(
                    f"Unit mismatch: {left.typ.unit} {right.typ.unit}",
                    self.expr,
                )
            if left.typ.positional and right.typ.positional and isinstance(self.expr.op, ast.Add):
                raise TypeMismatchException(
                    "Cannot add two positional units!",
                    self.expr,
                )
            new_unit = left.typ.unit or right.typ.unit

            # xor, as subtracting two positionals gives a delta
            new_positional = left.typ.positional ^ right.typ.positional

            op = 'add' if isinstance(self.expr.op, ast.Add) else 'sub'
            if ltyp == 'uint256' and isinstance(self.expr.op, ast.Add):
                o = LLLnode.from_list([
                    'seq',
                    # Checks that: a + b >= a
                    ['assert', ['ge', ['add', left, right], left]],
                    ['add', left, right],
                ], typ=BaseType('uint256', new_unit, new_positional), pos=getpos(self.expr))
            elif ltyp == 'uint256' and isinstance(self.expr.op, ast.Sub):
                o = LLLnode.from_list([
                    'seq',
                    # Checks that: a >= b
                    ['assert', ['ge', left, right]],
                    ['sub', left, right]
                ], typ=BaseType('uint256', new_unit, new_positional), pos=getpos(self.expr))
            elif ltyp == rtyp:
                # Signed types rely on the clamp applied at the end of this method.
                o = LLLnode.from_list(
                    [op, left, right],
                    typ=BaseType(ltyp, new_unit, new_positional),
                    pos=getpos(self.expr),
                )
            else:
                raise Exception(f"Unsupported Operation '{op}({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, ast.Mult):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot multiply positional values!", self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit)
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list([
                    'if',
                    ['eq', left, 0],
                    [0],
                    [
                        # Checks that: (a * b) / a == b
                        'seq', ['assert', ['eq', ['div', ['mul', left, right], left], right]],
                        ['mul', left, right]
                    ],
                ], typ=BaseType('uint256', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                o = LLLnode.from_list(
                    ['mul', left, right],
                    typ=BaseType('int128', new_unit),
                    pos=getpos(self.expr),
                )
            elif ltyp == rtyp == 'decimal':
                o = LLLnode.from_list([
                    'with', 'r', right, [
                        'with', 'l', left, [
                            'with', 'ans', ['mul', 'l', 'r'],
                            [
                                'seq',
                                [
                                    # Checks that: ans / l == r, or l == 0
                                    'assert',
                                    ['or', ['eq', ['sdiv', 'ans', 'l'], 'r'], ['iszero', 'l']]
                                ],
                                # The raw product is divided by DECIMAL_DIVISOR.
                                ['sdiv', 'ans', DECIMAL_DIVISOR],
                            ],
                        ],
                    ],
                ], typ=BaseType('decimal', new_unit), pos=getpos(self.expr))
            else:
                raise Exception(f"Unsupported Operation 'mul({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, ast.Div):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException("Cannot divide positional values!", self.expr)
            new_unit = combine_units(left.typ.unit, right.typ.unit, div=True)
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list([
                    'seq',
                    # Checks that:  b != 0
                    ['assert', right],
                    ['div', left, right],
                ], typ=BaseType('uint256', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                o = LLLnode.from_list(
                    ['sdiv', left, ['clamp_nonzero', right]],
                    typ=BaseType('int128', new_unit),
                    pos=getpos(self.expr),
                )
            elif ltyp == rtyp == 'decimal':
                o = LLLnode.from_list([
                    'with', 'l', left, [
                        'with', 'r', ['clamp_nonzero', right], [
                            'sdiv',
                            # The dividend is multiplied by DECIMAL_DIVISOR before dividing.
                            ['mul', 'l', DECIMAL_DIVISOR],
                            'r',
                        ],
                    ]
                ], typ=BaseType('decimal', new_unit), pos=getpos(self.expr))
            else:
                raise Exception(f"Unsupported Operation 'div({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, ast.Mod):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException(
                    "Cannot use positional values as modulus arguments!",
                    self.expr,
                )
            # NOTE(review): this only raises when the units are incompatible AND
            # neither operand carries a unit - confirm that is the intended condition.
            if not are_units_compatible(left.typ, right.typ) and not (left.typ.unit or right.typ.unit):  # noqa: E501
                raise TypeMismatchException("Modulus arguments must have same unit", self.expr)
            new_unit = left.typ.unit or right.typ.unit
            if ltyp == rtyp == 'uint256':
                o = LLLnode.from_list([
                    'seq',
                    # Checks that:  b != 0
                    ['assert', right],
                    ['mod', left, right]
                ], typ=BaseType('uint256', new_unit), pos=getpos(self.expr))
            elif ltyp == rtyp:
                o = LLLnode.from_list(
                    ['smod', left, ['clamp_nonzero', right]],
                    typ=BaseType(ltyp, new_unit),
                    pos=getpos(self.expr),
                )
            else:
                raise Exception(f"Unsupported Operation 'mod({ltyp}, {rtyp})'")
        elif isinstance(self.expr.op, ast.Pow):
            if left.typ.positional or right.typ.positional:
                raise TypeMismatchException(
                    "Cannot use positional values as exponential arguments!",
                    self.expr,
                )
            if right.typ.unit:
                raise TypeMismatchException(
                    "Cannot use unit values as exponents",
                    self.expr,
                )
            if ltyp != 'int128' and ltyp != 'uint256' and isinstance(self.expr.right, ast.Name):
                raise TypeMismatchException(
                    "Cannot use dynamic values as exponents, for unit base types",
                    self.expr,
                )
            if ltyp == rtyp == 'uint256':
                # NOTE(review): the uint256 result type carries no unit, unlike
                # the int128 branch below - confirm this is intended.
                o = LLLnode.from_list([
                    'seq',
                    [
                        # Checks that: r == 1, or r == 0, or l < l ** r
                        'assert',
                        [
                            'or',
                            ['or', ['eq', right, 1], ['iszero', right]],
                            ['lt', left, ['exp', left, right]]
                        ],
                    ],
                    ['exp', left, right],
                ], typ=BaseType('uint256'), pos=getpos(self.expr))
            elif ltyp == rtyp == 'int128':
                new_unit = left.typ.unit
                if left.typ.unit and not isinstance(self.expr.right, ast.Name):
                    # Re-key one unit name from the base with the exponent's
                    # literal value as its power.
                    new_unit = {left.typ.unit.copy().popitem()[0]: self.expr.right.n}
                o = LLLnode.from_list(
                    ['exp', left, right],
                    typ=BaseType('int128', new_unit),
                    pos=getpos(self.expr),
                )
            else:
                raise TypeMismatchException('Only whole number exponents are supported', self.expr)
        else:
            raise ParserException(f"Unsupported binary operator: {self.expr.op}", self.expr)

        # Emit any operand pre-allocation first, then the (possibly clamped) result.
        p = ['seq']

        if pre_alloc_left:
            p.append(pre_alloc_left)
        if pre_alloc_right:
            p.append(pre_alloc_right)

        if o.typ.typ == 'int128':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINNUM],
                o,
                ['mload', MemoryPositions.MAXNUM],
            ])
        elif o.typ.typ == 'decimal':
            p.append([
                'clamp',
                ['mload', MemoryPositions.MINDECIMAL],
                o,
                ['mload', MemoryPositions.MAXDECIMAL],
            ])
        elif o.typ.typ == 'uint256':
            # uint256 branches above already carry their own assertions.
            p.append(o)
        else:
            raise Exception(f"{o} {o.typ}")

        return LLLnode.from_list(p, typ=o.typ, pos=getpos(self.expr))