Beispiel #1
0
def test_version_check(evm_version):
    """Check `opcodes.version_check` against the active EVM version.

    A range that begins and/or ends at the active version must always match.
    Unless the active version is byzantium or atlantis, a range ending at
    byzantium must not match. A range beginning at istanbul matches exactly
    when the active version is ordered at or after istanbul in
    `opcodes.EVM_VERSIONS`.
    """
    check = opcodes.version_check

    same_version_ranges = (
        {"begin": evm_version},
        {"end": evm_version},
        {"begin": evm_version, "end": evm_version},
    )
    for bounds in same_version_ranges:
        assert check(**bounds)

    if evm_version not in ("byzantium", "atlantis"):
        assert not check(end="byzantium")

    got = check(begin="istanbul")
    ordering = opcodes.EVM_VERSIONS
    assert got == (ordering[evm_version] >= ordering["istanbul"])
Beispiel #2
0
def sar(bits, x):
    """Return IR for an arithmetic (sign-preserving) right shift of `x` by `bits`.

    On constantinople and later this is the native SAR opcode (EIP-145).

    Older rule sets must emulate SAR. Keep in mind the note from EIP-145:
    "This is not equivalent to PUSH1 2 EXP SDIV, since it rounds
    differently. See SDIV(-1, 2) == 0, while SAR(-1, 1) == -1."
    SDIV truncates toward zero, but SAR rounds toward negative infinity.

    The emulation uses the two's-complement identity

        sar(x, n) == x // 2**n            (x >= 0, unsigned DIV)
        sar(x, n) == ~((~x) // 2**n)      (x <  0; ~x == -x - 1 >= 0)

    Both cases are folded into one branch-free expression using the sign
    mask m = 0 - (x < 0), which is 0 for non-negative x and all ones for
    negative x:

        sar(x, n) == ((x ^ m) // 2**n) ^ m

    For n >= 256, EXP(2, n) wraps to 0 and EVM DIV by zero yields 0. The
    result is then 0 for non-negative x and -1 for negative x, which
    matches SAR.

    Arguments:
        bits: IR for the shift amount.
        x: IR for the value to shift.

    Returns:
        An IR list.
    """
    if version_check(begin="constantinople"):
        return ["sar", bits, x]

    # Bind `x` once so it is evaluated exactly once, even when it is a
    # complex expression (the previous emulation referenced it twice).
    return [
        "with",
        "_sar_x",
        x,
        [
            "with",
            "_sar_mask",
            # 0 - 1 wraps to all ones when x is negative, else 0
            ["sub", 0, ["slt", "_sar_x", 0]],
            [
                "xor",
                ["div", ["xor", "_sar_x", "_sar_mask"], ["exp", 2, bits]],
                "_sar_mask",
            ],
        ],
    ]
Beispiel #3
0
def safe_div(x, y):
    """Generate IR for checked division `x / y`.

    The divisor is guarded with `clamp("gt", y, 0)`. EVM GT is an unsigned
    comparison, so this presumably acts as a non-zero check for signed
    divisors as well (verify against `clamp`).

    For signed types, the one overflowing quotient (MIN_VALUE / -1) is
    rejected as well:
      - int256: explicit check, skipped when a literal operand rules it out.
      - narrower signed ints: `clamp_basetype`.
    Decimal dividends are pre-multiplied by the decimal divisor so the
    quotient keeps its fixed-point scale, and decimal results are always
    clamped to the type's bounds.

    Raises:
        UnimplementedException: for decimal types whose scaled bounds do
            not fit in 256 bits.

    Returns:
        An IRnode typed as `x.typ`.
    """
    num_info = x.typ._num_info
    typ = x.typ

    ok = [1]  # true

    if is_decimal_type(x.typ):
        lo, hi = num_info.bounds
        if max(abs(lo), abs(hi)) * num_info.divisor > 2**256 - 1:
            # stub to prevent us from adding fixed point numbers we don't know
            # how to deal with
            raise UnimplementedException(
                f"safe_div for decimal{num_info.bits}x{num_info.decimals}")
        # scale the dividend so that (x * divisor) / y keeps decimal scale
        x = ["mul", x, num_info.divisor]

    DIV = "sdiv" if num_info.is_signed else "div"
    res = IRnode.from_list([DIV, x, clamp("gt", y, 0)], typ=typ)
    with res.cache_when_complex("res") as (b1, res):

        # TODO: refactor this condition / push some things into the optimizer
        if num_info.is_signed and num_info.bits == 256:
            if version_check(begin="constantinople"):
                upper_bound = ["shl", 255, 1]
            else:
                upper_bound = -(2**255)

            # SDIV(-2**255, -1) overflows back to -2**255, so that pair must
            # be rejected explicitly. Use the node-level `is_literal` flag
            # for both operands (consistent with the branches below and with
            # safe_mul).
            if not x.is_literal and not y.is_literal:
                ok = ["or", ["ne", y, ["not", 0]], ["ne", x, upper_bound]]
            # TODO push these rules into the optimizer
            elif x.is_literal and x.value == -(2**255):
                ok = ["ne", y, ["not", 0]]
            elif y.is_literal and y.value == -1:
                ok = ["ne", x, upper_bound]
            else:
                # x or y is a literal, and not an evil value.
                pass

        elif num_info.is_signed and is_integer_type(typ):
            lo, hi = num_info.bounds
            # we need to throw on min_value(typ) / -1,
            # but we can skip if one of the operands is a literal and not
            # the evil value
            can_skip_clamp = (x.is_literal
                              and x.value != lo) or (y.is_literal
                                                     and y.value != -1)
            if not can_skip_clamp:
                # clamp_basetype has fewer ops than the int256 rule.
                res = clamp_basetype(res)

        elif is_decimal_type(typ):
            # always clamp decimals, since decimal division can actually
            # result in something larger than either operand (e.g. 1.0 / 0.1)
            # TODO maybe use safe_mul
            res = clamp_basetype(res)

        check = IRnode.from_list(["assert", ok], error_msg="safediv")
        # keep the result typed; previously the outer node had no type
        return IRnode.from_list(b1.resolve(["seq", check, res]), typ=typ)
Beispiel #4
0
def int128_clamp(lll_node):
    """Return LLL that asserts `lll_node` lies within the int128 range.

    There are two code paths, depending on the active EVM rule set:
    - constantinople and later: a shift-based check (SHR is only available
      from constantinople on).
    - older: the LLL `clamp` form against MIN_INT128 / MAX_INT128, which
      are loaded from fixed memory positions (`MemoryPositions`).

    The check relies on this fact: v is in [-2**127, 2**127) iff
    (v if v >= 0 else ~v) < 2**127, i.e. iff that value shifted right by
    127 bits is zero.

    NOTE(review): the constantinople branch uses `seq_unchecked` together
    with bare stack references (`dup1`, `pass`). It appears to operate on a
    copy of the value on the stack and to leave the value as the expression
    result. Confirm against the LLL compiler's handling of these forms
    before changing it.
    """
    if version_check(begin="constantinople"):
        return [
            "with",
            "_val",
            lll_node,
            [
                "seq_unchecked",
                # push a copy of the value onto the stack
                ["dup1", "_val"],
                # presumably: for negative values, replace the copy (`pass`,
                # the stack top) with its bitwise complement ~v == -v - 1
                ["if", ["slt", "_val", 0], ["not", "pass"]],
                # in range iff no bits remain after shifting right by 127
                ["assert", ["iszero", ["shr", 127, "pass"]]],
            ],
        ]
    else:
        return [
            "clamp",
            ["mload", MemoryPositions.MIN_INT128],
            lll_node,
            ["mload", MemoryPositions.MAX_INT128],
        ]
Beispiel #5
0
def get_nonreentrant_lock(func_type):
    """Build the entry/exit LLL fragments for a function's reentrancy lock.

    Returns a pair `(entry, exit)`; each is a list of LLL fragments.
    Functions without a nonreentrant key get `(["pass"], ["pass"])`.
    View functions only assert that the lock is not currently held.
    Other functions assert that, then take the lock on entry and release it
    on exit.
    """
    if not func_type.nonreentrant:
        return ["pass"], ["pass"]

    slot = func_type.reentrancy_key_position.position

    # From berlin on, any nonzero values would work here (see pricing as of
    # net gas metering). 3/2 are chosen so that downgrading to the 0/1
    # scheme, if it is somehow necessary, is safe.
    if not version_check(begin="berlin"):
        unlocked, locked = 0, 1
    else:
        unlocked, locked = 3, 2

    assert_unlocked = ["assert", ["ne", locked, ["sload", slot]]]

    if func_type.mutability != StateMutability.VIEW:
        acquire = ["seq", assert_unlocked, ["sstore", slot, locked]]
        release = ["sstore", slot, unlocked]
        return [acquire], [release]

    # view functions check the lock but never take it
    return [assert_unlocked], [["seq"]]
Beispiel #6
0
def safe_mul(x, y):
    """Generate IR for checked multiplication `x * y`.

    Precondition: x.typ.typ == y.typ.typ.

    Overflow is detected in two stages:
      1. For types wider than 128 bits, the 256-bit product is verified by
         dividing it back (`res / y == x`, or `y == 0`). For int256 there is
         an extra guard for -2**255 * -1, because SDIV overflow makes the
         division check pass for that pair.
      2. `clamp_basetype` enforces the bounds of the type itself.
    Decimal products are divided by the decimal divisor (to restore the
    fixed-point scale) before clamping.

    Returns:
        The IR produced by resolving the cached-product builder around
        `["seq", assert, result]`.
    """
    num_info = x.typ._num_info

    # optimizer rules work better for the safemul checks below
    # if second operand is literal
    if x.is_literal:
        x, y = y, x

    res = IRnode.from_list(["mul", x, y], typ=x.typ.typ)

    DIV = "sdiv" if num_info.is_signed else "div"

    with res.cache_when_complex("ans") as (b1, res):

        ok = [1]  # True

        if num_info.bits > 128:  # check overflow mod 256
            # assert (res/y == x | y == 0)
            ok = ["or", ["eq", [DIV, res, y], x], ["iszero", y]]

        # int256
        if num_info.is_signed and num_info.bits == 256:
            # special case:
            # in the above sdiv check, if (r==-1 and l==-2**255),
            # -2**255<res> / -1<r> will return -2**255<l>.
            # need to check: not (r == -1 and l == -2**255)
            if version_check(begin="constantinople"):
                upper_bound = ["shl", 255, 1]
            else:
                upper_bound = -(2**255)

            check_x = ["ne", x, upper_bound]
            # ~y == 0 exactly when y == -1
            check_y = ["ne", ["not", y], 0]

            if not x.is_literal and not y.is_literal:
                # TODO can simplify this condition?
                ok = ["and", ok, ["or", check_x, check_y]]

            # TODO push some of this constant folding into optimizer
            elif x.is_literal and x.value == -(2**255):
                ok = ["and", ok, check_y]
            elif y.is_literal and y.value == -1:
                ok = ["and", ok, check_x]
            else:
                # x or y is a literal, and we have determined it is
                # not an evil value
                pass

        if is_decimal_type(res.typ):
            res = IRnode.from_list([DIV, res, num_info.divisor], typ=res.typ)

        # check overflow mod <bits>
        # NOTE: if 128 < bits < 256, `x * y` could be between
        # MAX_<bits> and 2**256 OR it could overflow past 2**256.
        # so, we check for overflow in mod 256 AS WELL AS mod <bits>
        # (if bits == 256, clamp_basetype is a no-op)
        res = clamp_basetype(res)

        # error_msg previously read "safediv" (swapped with safe_div)
        check = IRnode.from_list(["assert", ok], error_msg="safemul")
        res = IRnode.from_list(["seq", check, res], typ=res.typ)

        return b1.resolve(res)
Beispiel #7
0
def shl(bits, x):
    """Return IR for a logical left shift of `x` by `bits`.

    Constantinople and later use the native SHL opcode. Older rule sets
    emulate it as `x * 2**bits`; EVM MUL and EXP both wrap modulo 2**256.
    """
    if not version_check(begin="constantinople"):
        factor = ["exp", 2, bits]
        return ["mul", x, factor]
    return ["shl", bits, x]
Beispiel #8
0
    def parse_BinOp(self):
        """Lower a numeric binary operation (`+ - * / % **`) to LLL.

        Both operands are parsed as value expressions. Returns None when:
          - either operand is non-numeric,
          - the divisor/modulus is a literal zero,
          - the type/operator combination is unsupported,
          - `a ** b` has no integer literal on either side.
        Otherwise returns an LLLnode typed as the common operand type. The
        node includes the overflow/bounds checks that type requires.
        """
        left = Expr.parse_value_expr(self.expr.left, self.context)
        right = Expr.parse_value_expr(self.expr.right, self.context)

        if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
            return

        pos = getpos(self.expr)
        types = {left.typ.typ, right.typ.typ}
        literals = {left.typ.is_literal, right.typ.is_literal}

        # If one value of the operation is a literal, we recast it to match the non-literal type.
        # We know this is OK because types were already verified in the actual typechecking pass.
        # This is a temporary solution to not break codegen while we work toward removing types
        # altogether at this stage of compilation. @iamdefinitelyahuman
        if literals == {True, False
                        } and len(types) > 1 and "decimal" not in types:
            if left.typ.is_literal and SizeLimits.in_bounds(
                    right.typ.typ, left.value):
                left = LLLnode.from_list(
                    left.value,
                    typ=BaseType(right.typ.typ, is_literal=True),
                    pos=pos,
                )
            elif right.typ.is_literal and SizeLimits.in_bounds(
                    left.typ.typ, right.value):
                right = LLLnode.from_list(
                    right.value,
                    typ=BaseType(left.typ.typ, is_literal=True),
                    pos=pos,
                )

        ltyp, rtyp = left.typ.typ, right.typ.typ

        # Sanity check - ensure that we aren't dealing with different types
        # This should be unreachable due to the type check pass
        assert ltyp == rtyp, "unreachable"

        # `arith` is written in terms of the names "l" and "r"; they are bound
        # to the operands (each evaluated once) by the `with` wrapper at the
        # end of this method.
        arith = None
        if isinstance(self.expr.op, (vy_ast.Add, vy_ast.Sub)):
            new_typ = BaseType(ltyp)

            if ltyp == "uint256":
                if isinstance(self.expr.op, vy_ast.Add):
                    # safeadd
                    arith = [
                        "seq", ["assert", ["ge", ["add", "l", "r"], "l"]],
                        ["add", "l", "r"]
                    ]

                elif isinstance(self.expr.op, vy_ast.Sub):
                    # safesub
                    arith = [
                        "seq", ["assert", ["ge", "l", "r"]], ["sub", "l", "r"]
                    ]

            elif ltyp == "int256":
                if isinstance(self.expr.op, vy_ast.Add):
                    op, comp1, comp2 = "add", "sge", "slt"
                else:
                    op, comp1, comp2 = "sub", "sle", "sgt"

                if right.typ.is_literal:
                    # the sign of a literal r is known, so only one
                    # direction of overflow needs checking
                    if right.value >= 0:
                        arith = [
                            "seq", ["assert", [comp1, [op, "l", "r"], "l"]],
                            [op, "l", "r"]
                        ]
                    else:
                        arith = [
                            "seq", ["assert", [comp2, [op, "l", "r"], "l"]],
                            [op, "l", "r"]
                        ]
                else:
                    arith = [
                        "with",
                        "ans",
                        [op, "l", "r"],
                        [
                            "seq",
                            [
                                "assert",
                                [
                                    "or",
                                    [
                                        "and", ["sge", "r", 0],
                                        [comp1, "ans", "l"]
                                    ],
                                    [
                                        "and", ["slt", "r", 0],
                                        [comp2, "ans", "l"]
                                    ],
                                ],
                            ],
                            "ans",
                        ],
                    ]

            elif ltyp in ("decimal", "int128", "uint8"):
                # range is enforced by clamp_basetype below
                op = "add" if isinstance(self.expr.op, vy_ast.Add) else "sub"
                arith = [op, "l", "r"]

        elif isinstance(self.expr.op, vy_ast.Mult):
            new_typ = BaseType(ltyp)
            if ltyp == "uint256":
                arith = [
                    "with",
                    "ans",
                    ["mul", "l", "r"],
                    [
                        "seq",
                        [
                            "assert",
                            [
                                "or", ["eq", ["div", "ans", "l"], "r"],
                                ["iszero", "l"]
                            ]
                        ],
                        "ans",
                    ],
                ]

            elif ltyp == "int256":
                if version_check(begin="constantinople"):
                    upper_bound = ["shl", 255, 1]
                else:
                    upper_bound = -(2**255)
                if not left.typ.is_literal and not right.typ.is_literal:
                    bounds_check = [
                        "assert",
                        [
                            "or", ["ne", "l", ["not", 0]],
                            ["ne", "r", upper_bound]
                        ],
                    ]
                elif left.typ.is_literal and left.value == -1:
                    bounds_check = ["assert", ["ne", "r", upper_bound]]
                elif right.typ.is_literal and right.value == -(2**255):
                    bounds_check = ["assert", ["ne", "l", ["not", 0]]]
                else:
                    bounds_check = "pass"
                arith = [
                    "with",
                    "ans",
                    ["mul", "l", "r"],
                    [
                        "seq",
                        bounds_check,
                        [
                            "assert",
                            [
                                "or", ["eq", ["sdiv", "ans", "l"], "r"],
                                ["iszero", "l"]
                            ]
                        ],
                        "ans",
                    ],
                ]

            elif ltyp in ("int128", "uint8"):
                arith = ["mul", "l", "r"]

            elif ltyp == "decimal":
                arith = [
                    "with",
                    "ans",
                    ["mul", "l", "r"],
                    [
                        "seq",
                        [
                            "assert",
                            [
                                "or", ["eq", ["sdiv", "ans", "l"], "r"],
                                ["iszero", "l"]
                            ]
                        ],
                        ["sdiv", "ans", DECIMAL_DIVISOR],
                    ],
                ]

        elif isinstance(self.expr.op, vy_ast.Div):
            if right.typ.is_literal and right.value == 0:
                return

            new_typ = BaseType(ltyp)

            if right.typ.is_literal:
                divisor = "r"
            else:
                # only apply the non-zero clamp when r is not a constant
                divisor = ["clamp_nonzero", "r"]

            if ltyp in ("uint8", "uint256"):
                arith = ["div", "l", divisor]

            elif ltyp == "int256":
                if version_check(begin="constantinople"):
                    upper_bound = ["shl", 255, 1]
                else:
                    upper_bound = -(2**255)
                if not left.typ.is_literal and not right.typ.is_literal:
                    bounds_check = [
                        "assert",
                        [
                            "or", ["ne", "r", ["not", 0]],
                            ["ne", "l", upper_bound]
                        ],
                    ]
                elif left.typ.is_literal and left.value == -(2**255):
                    bounds_check = ["assert", ["ne", "r", ["not", 0]]]
                elif right.typ.is_literal and right.value == -1:
                    bounds_check = ["assert", ["ne", "l", upper_bound]]
                else:
                    bounds_check = "pass"
                arith = ["seq", bounds_check, ["sdiv", "l", divisor]]

            elif ltyp == "int128":
                arith = ["sdiv", "l", divisor]

            elif ltyp == "decimal":
                arith = [
                    "sdiv",
                    ["mul", "l", DECIMAL_DIVISOR],
                    divisor,
                ]

        elif isinstance(self.expr.op, vy_ast.Mod):
            if right.typ.is_literal and right.value == 0:
                return

            new_typ = BaseType(ltyp)

            if right.typ.is_literal:
                divisor = "r"
            else:
                # only apply the non-zero clamp when r is not a constant
                divisor = ["clamp_nonzero", "r"]

            if ltyp in ("uint8", "uint256"):
                arith = ["mod", "l", divisor]
            else:
                arith = ["smod", "l", divisor]

        elif isinstance(self.expr.op, vy_ast.Pow):
            new_typ = BaseType(ltyp)

            if self.expr.left.get("value") == 1:
                return LLLnode.from_list([1], typ=new_typ, pos=pos)
            if self.expr.left.get("value") == 0:
                return LLLnode.from_list(["iszero", right],
                                         typ=new_typ,
                                         pos=pos)

            if ltyp == "int128":
                is_signed = True
                num_bits = 128
            elif ltyp == "int256":
                is_signed = True
                num_bits = 256
            elif ltyp == "uint8":
                is_signed = False
                num_bits = 8
            else:
                is_signed = False
                num_bits = 256

            if isinstance(self.expr.left, vy_ast.Int):
                value = self.expr.left.value
                upper_bound = calculate_largest_power(value, num_bits,
                                                      is_signed) + 1
                # for signed integers, this also prevents negative values
                bound_check = ["lt", "r", upper_bound]
            elif isinstance(self.expr.right, vy_ast.Int):
                value = self.expr.right.value
                upper_bound = calculate_largest_base(value, num_bits,
                                                     is_signed) + 1
                if is_signed:
                    bound_check = [
                        "and", ["slt", "l", upper_bound],
                        ["sgt", "l", -upper_bound]
                    ]
                else:
                    bound_check = ["lt", "l", upper_bound]
            else:
                # `a ** b` where neither `a` or `b` are known
                # TODO this is currently unreachable, once we implement a way to do it safely
                # remove the check in `vyper/context/types/value/numeric.py`
                return

            # Bind both operands once. The non-literal operand is used in the
            # bound check AND in `exp`; referencing the raw node twice would
            # evaluate it twice. That duplicates side effects (e.g. an
            # internal call) and can check a different value than the one
            # used (e.g. `msg.gas`).
            return LLLnode.from_list(
                [
                    "with",
                    "l",
                    left,
                    [
                        "with",
                        "r",
                        right,
                        ["seq", ["assert", bound_check], ["exp", "l", "r"]],
                    ],
                ],
                typ=new_typ,
                pos=pos,
            )

        if arith is None:
            return

        arith = LLLnode.from_list(arith, typ=new_typ)

        p = [
            "with",
            "l",
            left,
            [
                "with",
                "r",
                right,
                # note clamp_basetype is a noop on [u]int256
                # note: clamp_basetype throws on unclampable input
                clamp_basetype(arith),
            ],
        ]
        return LLLnode.from_list(p, typ=new_typ, pos=pos)
Beispiel #9
0
 def parse_Attribute(self):
     """Lower an attribute access expression (`<value>.<attr>`) to LLL.

     Handles:
       - address members: `balance`, `codesize`, `is_contract`,
         `codehash`, `code`;
       - storage variables read through `self.<name>`;
       - environment variables: `msg.*`, `block.*`, `tx.*`, `chain.id`;
       - interface-typed values and struct members.
     Falls through and returns None when no case applies.

     Raises EvmVersionException for `address.codehash` before
     constantinople and for `chain.id` before istanbul.
     """
     # x.balance: balance of address x
     if self.expr.attr == "balance":
         addr = Expr.parse_value_expr(self.expr.value, self.context)
         if is_base_type(addr.typ, "address"):
             # `self.balance` uses SELFBALANCE when it is available (istanbul+)
             if (isinstance(self.expr.value, vy_ast.Name)
                     and self.expr.value.id == "self"
                     and version_check(begin="istanbul")):
                 seq = ["selfbalance"]
             else:
                 seq = ["balance", addr]
             return LLLnode.from_list(
                 seq,
                 typ=BaseType("uint256"),
                 location=None,
                 pos=getpos(self.expr),
             )
     # x.codesize: codesize of address x
     elif self.expr.attr == "codesize" or self.expr.attr == "is_contract":
         addr = Expr.parse_value_expr(self.expr.value, self.context)
         if is_base_type(addr.typ, "address"):
             if self.expr.attr == "codesize":
                 # Only a bare `self` name is the executing contract. Check the
                 # node type first (as the `balance` branch does): other
                 # address expressions such as `self.owner` are not `Name`
                 # nodes and may have no `.id` attribute.
                 if (isinstance(self.expr.value, vy_ast.Name)
                         and self.expr.value.id == "self"):
                     eval_code = ["codesize"]
                 else:
                     eval_code = ["extcodesize", addr]
                 output_type = "uint256"
             else:
                 eval_code = ["gt", ["extcodesize", addr], 0]
                 output_type = "bool"
             return LLLnode.from_list(
                 eval_code,
                 typ=BaseType(output_type),
                 location=None,
                 pos=getpos(self.expr),
             )
     # x.codehash: keccak of address x
     elif self.expr.attr == "codehash":
         addr = Expr.parse_value_expr(self.expr.value, self.context)
         if not version_check(begin="constantinople"):
             raise EvmVersionException(
                 "address.codehash is unavailable prior to constantinople ruleset",
                 self.expr)
         if is_base_type(addr.typ, "address"):
             return LLLnode.from_list(
                 ["extcodehash", addr],
                 typ=BaseType("bytes32"),
                 location=None,
                 pos=getpos(self.expr),
             )
     # x.code: codecopy/extcodecopy of address x
     elif self.expr.attr == "code":
         addr = Expr.parse_value_expr(self.expr.value, self.context)
         if is_base_type(addr.typ, "address"):
             # These adhoc nodes will be replaced with a valid node in `Slice.build_LLL`
             if addr.value == "address":  # for `self.code`
                 return LLLnode.from_list(["~selfcode"],
                                          typ=ByteArrayType(0))
             return LLLnode.from_list(["~extcode", addr],
                                      typ=ByteArrayType(0))
     # self.x: global attribute
     elif isinstance(self.expr.value,
                     vy_ast.Name) and self.expr.value.id == "self":
         type_ = self.expr._metadata["type"]
         var = self.context.globals[self.expr.attr]
         return LLLnode.from_list(
             type_.position.position,
             typ=var.typ,
             location="storage",
             pos=getpos(self.expr),
             annotation="self." + self.expr.attr,
         )
     # Reserved keywords
     elif (isinstance(self.expr.value, vy_ast.Name)
           and self.expr.value.id in ENVIRONMENT_VARIABLES):
         key = f"{self.expr.value.id}.{self.expr.attr}"
         if key == "msg.sender":
             return LLLnode.from_list(["caller"],
                                      typ="address",
                                      pos=getpos(self.expr))
         elif key == "msg.data":
             # This adhoc node will be replaced with a valid node in `Slice/Len.build_LLL`
             return LLLnode.from_list(["~calldata"], typ=ByteArrayType(0))
         elif key == "msg.value" and self.context.is_payable:
             return LLLnode.from_list(
                 ["callvalue"],
                 typ=BaseType("uint256"),
                 pos=getpos(self.expr),
             )
         elif key == "msg.gas":
             return LLLnode.from_list(
                 ["gas"],
                 typ="uint256",
                 pos=getpos(self.expr),
             )
         elif key == "block.difficulty":
             return LLLnode.from_list(
                 ["difficulty"],
                 typ="uint256",
                 pos=getpos(self.expr),
             )
         elif key == "block.timestamp":
             return LLLnode.from_list(
                 ["timestamp"],
                 typ=BaseType("uint256"),
                 pos=getpos(self.expr),
             )
         elif key == "block.coinbase":
             return LLLnode.from_list(["coinbase"],
                                      typ="address",
                                      pos=getpos(self.expr))
         elif key == "block.number":
             return LLLnode.from_list(["number"],
                                      typ="uint256",
                                      pos=getpos(self.expr))
         elif key == "block.gaslimit":
             return LLLnode.from_list(["gaslimit"],
                                      typ="uint256",
                                      pos=getpos(self.expr))
         elif key == "block.basefee":
             return LLLnode.from_list(["basefee"],
                                      typ="uint256",
                                      pos=getpos(self.expr))
         elif key == "block.prevhash":
             return LLLnode.from_list(
                 ["blockhash", ["sub", "number", 1]],
                 typ="bytes32",
                 pos=getpos(self.expr),
             )
         elif key == "tx.origin":
             return LLLnode.from_list(["origin"],
                                      typ="address",
                                      pos=getpos(self.expr))
         elif key == "tx.gasprice":
             return LLLnode.from_list(["gasprice"],
                                      typ="uint256",
                                      pos=getpos(self.expr))
         elif key == "chain.id":
             if not version_check(begin="istanbul"):
                 raise EvmVersionException(
                     "chain.id is unavailable prior to istanbul ruleset",
                     self.expr)
             return LLLnode.from_list(["chainid"],
                                      typ="uint256",
                                      pos=getpos(self.expr))
     # Other variables
     else:
         sub = Expr(self.expr.value, self.context).lll_node
         # contract type
         if isinstance(sub.typ, InterfaceType):
             return sub
         if isinstance(sub.typ,
                       StructType) and self.expr.attr in sub.typ.members:
             return get_element_ptr(sub,
                                    self.expr.attr,
                                    pos=getpos(self.expr))
Beispiel #10
0
def _optimize_binop(binop, args, ann, parent_op):
    """Try to rewrite the binary IR operation `binop(args[0], args[1])`.

    Arguments:
        binop: name of the operation. It is looked up in the `arith` table,
            which yields the constant-folding function, the symbol used in
            annotations, and the default signedness used when reading
            literal operands.
        args: the operand IR nodes (accessed as args[0] and args[1]).
        ann: annotation of the original node, or None.
        parent_op: op of the enclosing IR node. When it is `if`, `assert`
            or `iszero`, only the truthiness of the result matters, which
            enables some extra rewrites.

    Returns:
        None when no rewrite applies, or when the rewrite would drop an
        operand marked `is_complex_ir` that the original node contained.
        Otherwise a tuple `(new_value, new_args, new_annotation)` describing
        the replacement node.

    NOTE: the rules are order-sensitive (see the truthy-rewrite note below);
    reorder with care.
    """
    fn, symb, unsigned = arith[binop]

    # local version of _evm_int which defaults to the current binop's signedness
    def _int(x, unsigned=unsigned):
        return _evm_int(x, unsigned=unsigned)

    # wrap a folded value using the current binop's signedness
    def _wrap(x):
        return _wrap256(x, unsigned=unsigned)

    # human-readable annotation such as "<ann> (a+b)", only if the original
    # node was annotated
    new_ann = None
    if ann is not None:
        l_ann = _shorten_annotation(args[0].annotation or str(args[0]))
        r_ann = _shorten_annotation(args[1].annotation or str(args[1]))
        new_ann = l_ann + symb + r_ann
        new_ann = f"{ann} ({new_ann})"

    def finalize(new_val, new_args):
        # if the original had side effects which might not be in the
        # optimized output, roll back the optimization
        rollback = (args[0].is_complex_ir and not _deep_contains(
            new_args, args[0])) or (args[1].is_complex_ir
                                    and not _deep_contains(new_args, args[1]))

        if rollback:
            return None

        return new_val, new_args, new_ann

    if _is_int(args[0]) and _is_int(args[1]):
        # compile-time arithmetic
        left, right = _int(args[0]), _int(args[1])
        new_val = fn(left, right)
        # wrap the result, since `fn` generally does not wrap.
        # (note: do not rely on wrapping/non-wrapping behavior for `fn`!
        # some ops, like evm_pow, ALWAYS wrap).
        new_val = _wrap(new_val)
        return finalize(new_val, [])

    # we can return truthy values instead of actual math
    is_truthy = parent_op in {"if", "assert", "iszero"}

    def _conservative_eq(x, y):
        # whether x evaluates to the same value as y at runtime.
        # TODO we can do better than this check, but we need to be
        # conservative in case x has side effects.
        return x.args == y.args == [] and x.value == y.value and not x.is_complex_ir

    ##
    # ARITHMETIC AND BITWISE OPS
    ##

    # for commutative ops, move the literal to the second
    # position to make the later logic cleaner
    if binop in COMMUTATIVE_OPS and _is_int(args[0]):
        args = [args[1], args[0]]

    if binop in {"add", "sub", "xor", "or"} and _int(args[1]) == 0:
        # x + 0 == x - 0 == x | 0 == x ^ 0 == x
        return finalize("seq", [args[0]])

    if binop in {"sub", "xor", "ne"} and _conservative_eq(args[0], args[1]):
        # x - x == x ^ x == x != x == 0
        return finalize(0, [])

    if binop == "eq" and _conservative_eq(args[0], args[1]):
        # (x == x) == 1
        return finalize(1, [])

    # TODO associativity rules

    # x * 0 == x / 0 == x % 0 == x & 0 == 0
    # (EVM division/modulo by zero yields 0)
    if binop in {"mul", "div", "sdiv", "mod", "smod", "and"} and _int(
            args[1]) == 0:
        return finalize(0, [])

    # x % 1 == 0
    if binop in {"mod", "smod"} and _int(args[1]) == 1:
        return finalize(0, [])

    # x * 1 == x / 1 == x
    if binop in {"mul", "div", "sdiv"} and _int(args[1]) == 1:
        return finalize("seq", [args[0]])

    # x * -1 == 0 - x
    if binop in {"mul", "sdiv"} and _int(args[1], SIGNED) == -1:
        return finalize("sub", [0, args[0]])

    # a literal -1 here is the all-ones word
    if binop in {"and", "or", "xor"} and _int(args[1], SIGNED) == -1:
        assert unsigned == UNSIGNED
        if binop == "and":
            # -1 & x == x
            return finalize("seq", [args[0]])

        if binop == "xor":
            # -1 ^ x == ~x
            return finalize("not", [args[0]])

        if binop == "or":
            # -1 | x == -1
            return finalize(args[1].value, [])

        raise CompilerPanic("unreachable")  # pragma: notest

    # -1 - x == ~x (definition of two's complement)
    if binop == "sub" and _int(args[0], SIGNED) == -1:
        return finalize("not", [args[1]])

    if binop == "exp":
        # n ** 0 == 1 (forall n)
        # 1 ** n == 1
        if _int(args[1]) == 0 or _int(args[0]) == 1:
            return finalize(1, [])
        # 0 ** n == (1 if n == 0 else 0)
        if _int(args[0]) == 0:
            return finalize("iszero", [args[1]])
        # n ** 1 == n
        if _int(args[1]) == 1:
            return finalize("seq", [args[0]])

    # TODO: check me! reduce codesize for negative numbers
    # if binop in {"add", "sub"} and _int(args[1], SIGNED) < 0:
    #     flipped = "add" if binop == "sub" else "sub"
    #     return finalize(flipped, [args[0], -args[1]])

    # TODO maybe OK:
    # elif binop == "div" and _int(args[1], UNSIGNED) == MAX_UINT256:
    #    # (div x (2**256 - 1)) == (eq x (2**256 - 1))
    #    new_val = "eq"
    #    args = args

    if binop in {"mod", "div", "mul"} and _is_int(args[1]) and is_power_of_two(
            _int(args[1])):
        assert unsigned == UNSIGNED, "something's not right."
        # shave two gas off mod/div/mul for powers of two
        # x % 2**n == x & (2**n - 1)
        if binop == "mod":
            return finalize("and", [args[0], _int(args[1]) - 1])

        if binop == "div" and version_check(begin="constantinople"):
            # x / 2**n == x >> n
            # recall shr/shl have unintuitive arg order
            return finalize("shr", [int_log2(_int(args[1])), args[0]])

        # note: no rule for sdiv since it rounds differently from sar
        if binop == "mul" and version_check(begin="constantinople"):
            # x * 2**n == x << n
            return finalize("shl", [int_log2(_int(args[1])), args[0]])

        # reachable but only before constantinople
        # (pre-constantinople div/mul fall through to the rules below)
        if version_check(begin="constantinople"):  # pragma: no cover
            raise CompilerPanic("unreachable")

    ##
    # COMPARISONS
    ##

    if binop == "eq" and _int(args[1]) == 0:
        return finalize("iszero", [args[0]])

    # can't improve gas but can improve codesize
    if binop == "ne" and _int(args[1]) == 0:
        return finalize("iszero", [["iszero", args[0]]])

    if binop == "eq" and _int(args[1], SIGNED) == -1:
        # equal gas, but better codesize
        # x == MAX_UINT256 => ~x == 0
        return finalize("iszero", [["not", args[0]]])

    # note: in places where truthy is accepted, sequences of
    # ISZERO ISZERO will be optimized out, so we try to rewrite
    # some operations to include iszero
    # (note ordering; truthy optimizations should come first
    # to avoid getting clobbered by other branches)
    if is_truthy:
        if binop == "eq":
            assert unsigned == UNSIGNED
            # (eq x y) has the same truthyness as (iszero (xor x y))
            # it also has the same truthyness as (iszero (sub x y)),
            # but xor is slightly easier to optimize because of being
            # commutative.
            # note that (xor (-1) x) has its own rule
            return finalize("iszero", [["xor", args[0], args[1]]])

        if binop == "ne":
            # trigger other optimizations
            return finalize("iszero", [["eq", *args]])

        # TODO can we do this?
        # if val == "div":
        #     return finalize("gt", ["iszero", args])

        if binop == "or" and _is_int(args[1]) and _int(args[1]) != 0:
            # (x | y != 0) for any (y != 0)
            return finalize(1, [])

    if binop in COMPARISON_OPS:
        prefer_strict = not is_truthy
        res = _comparison_helper(binop, args, prefer_strict=prefer_strict)
        if res is None:
            return res
        new_op, new_args = res
        return finalize(new_op, new_args)

    # no optimization happened
    return None
Beispiel #11
0
def sar(bits, x):
    """Return IR for the EVM ``SAR`` (arithmetic right shift) of *x* by *bits*.

    The native opcode is only emitted for constantinople or later; for older
    EVM versions there is no emulation and NotImplementedError is raised.
    """
    if not version_check(begin="constantinople"):
        raise NotImplementedError("no SAR emulation for pre-constantinople EVM")
    return ["sar", bits, x]
Beispiel #12
0
def make_arg_clamper(datapos, mempos, typ, is_init=False):
    """
    Clamps argument to type limits.

    Arguments
    ---------
    datapos : int | LLLnode
        Offset of the value being clamped: into calldata (after the 4-byte
        method id) at runtime, or into the constructor arguments located at
        ``~codelen`` in code when ``is_init`` is True
    mempos : int | LLLnode
        Memory offset that the value is stored at during clamping
    typ : vyper.types.types.BaseType
        Type of the value
    is_init : bool, optional
        Boolean indicating if we are generating init bytecode

    Returns
    -------
    LLLnode
        Arg clamper LLL
    """

    # Pick the data source once so every word we read (the value itself and,
    # for byte arrays, the length word) comes from the same place.
    if not is_init:

        def load_word(ptr):
            # runtime: arguments live in calldata after the 4-byte method id
            return ["calldataload", ["add", 4, ptr]]

        copier = functools.partial(_mk_calldatacopy_copier, mempos=mempos)
    else:

        def load_word(ptr):
            # init: constructor arguments are read from code at ~codelen
            return ["codeload", ["add", "~codelen", ptr]]

        copier = functools.partial(_mk_codecopy_copier, mempos=mempos)

    data_decl = load_word(datapos)

    # Numbers: make sure they're in range
    if is_base_type(typ, "int128"):
        return LLLnode.from_list(int128_clamp(data_decl),
                                 typ=typ,
                                 annotation="checking int128 input")
    # Booleans: make sure they're zero or one
    elif is_base_type(typ, "bool"):
        if version_check(begin="constantinople"):
            # any bit above bit 0 set => not a valid bool
            lll = ["assert", ["iszero", ["shr", 1, data_decl]]]
        else:
            lll = ["uclamplt", data_decl, 2]
        return LLLnode.from_list(lll,
                                 typ=typ,
                                 annotation="checking bool input")
    # Addresses: make sure they're in range
    elif is_base_type(typ, "address"):
        return LLLnode.from_list(address_clamp(data_decl),
                                 typ=typ,
                                 annotation="checking address input")
    # Bytes: make sure they have the right size
    elif isinstance(typ, ByteArrayLike):
        # data_decl points at the length-prefixed payload (the copier copies
        # 32 + maxlen bytes from there). Read the length word from the same
        # source as the payload: in init code the arguments come from codeload,
        # so a calldataload here would not see them.
        return LLLnode.from_list(
            [
                "seq",
                copier(data_decl, 32 + typ.maxlen),
                ["assert", ["le", load_word(data_decl), typ.maxlen]],
            ],
            typ=None,
            annotation="checking bytearray input",
        )
    # Lists: recurse
    elif isinstance(typ, ListType):
        if typ.count > 5 or (type(datapos) is list and type(mempos) is list):
            # find ultimate base type
            subtype = typ.subtype
            while hasattr(subtype, "subtype"):
                subtype = subtype.subtype

            # make arg clamper for the base type
            offset = MemoryPositions.FREE_LOOP_INDEX
            clamper = make_arg_clamper(
                ["add", datapos, ["mload", offset]],
                ["add", mempos, ["mload", offset]],
                subtype,
                is_init,
            )
            if clamper.value == "pass":
                # no point looping if the base type doesn't require clamping
                return clamper

            # loop the entire array at once, even if it's multidimensional
            type_size = get_size_of_type(typ)
            i_incr = get_size_of_type(subtype) * 32

            mem_to = type_size * 32
            loop_label = f"_check_list_loop_{uuid.uuid4()}"

            lll_node = [
                ["mstore", offset, 0],  # init loop
                ["label", loop_label],
                clamper,
                ["mstore", offset, ["add", ["mload", offset], i_incr]],
                [
                    "if", ["lt", ["mload", offset], mem_to],
                    ["goto", loop_label]
                ],
            ]
        else:
            # small static lists: unroll one clamper per element
            lll_node = []
            for i in range(typ.count):
                offset = get_size_of_type(typ.subtype) * 32 * i
                lll_node.append(
                    make_arg_clamper(datapos + offset, mempos + offset,
                                     typ.subtype, is_init))
        return LLLnode.from_list(["seq"] + lll_node,
                                 typ=None,
                                 annotation="checking list input")
    # Otherwise don't make any checks
    else:
        return LLLnode.from_list("pass")
Beispiel #13
0
def address_clamp(lll_node):
    """Wrap *lll_node* in IR that rejects values wider than 160 bits.

    On constantinople or later, this asserts that shifting the value right by
    160 leaves zero. Older EVM versions use ``uclamplt`` against the word
    stored at ``MemoryPositions.ADDRSIZE`` instead.
    """
    if not version_check(begin="constantinople"):
        return ["uclamplt", lll_node, ["mload", MemoryPositions.ADDRSIZE]]
    high_bits = ["shr", 160, lll_node]
    return ["assert", ["iszero", high_bits]]
Beispiel #14
0
def to_int256(expr, args, kwargs, context):
    """
    Generate LLL that converts ``args[0]`` to an ``int256`` value.

    - numeric literals are range-checked at compile time (decimals are
      truncated toward zero with ``math.trunc`` first);
    - ``int128`` and ``bool`` values are re-tagged as ``int256`` unchanged;
    - ``uint256`` values are clamped (unsigned) below 2**255;
    - ``decimal`` values are ``sdiv``-ed by ``DECIMAL_DIVISOR``;
    - ``bytes32`` / ``address`` values are reinterpreted as ``int256``;
    - ``Bytes`` / ``String`` of max length <= 32 go through
      ``byte_array_to_num``.

    Raises
    ------
    InvalidLiteral
        If a literal is out of range or of an unknown numeric type, or the
        input type is not supported.
    TypeMismatch
        If a ``Bytes``/``String`` input may be longer than 32 bytes.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "num_literal":
        # pass `expr` to every error so it carries source position, like the
        # fallback error at the bottom of this function
        if isinstance(in_arg, int):
            if not SizeLimits.in_bounds("int256", in_arg):
                raise InvalidLiteral(f"Number out of range: {in_arg}", expr)
            return LLLnode.from_list(in_arg,
                                     typ=BaseType("int256"),
                                     pos=getpos(expr))
        elif isinstance(in_arg, Decimal):
            truncated = math.trunc(in_arg)
            if not SizeLimits.in_bounds("int256", truncated):
                raise InvalidLiteral(f"Number out of range: {truncated}", expr)
            return LLLnode.from_list(truncated,
                                     typ=BaseType("int256"),
                                     pos=getpos(expr))
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}", expr)

    elif isinstance(in_arg, LLLnode) and input_type == "int128":
        return LLLnode.from_list(in_arg,
                                 typ=BaseType("int256"),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == "uint256":
        # 2**255, built with SHL when available, otherwise as the literal
        # -(2**255) (same 256-bit pattern in two's complement)
        if version_check(begin="constantinople"):
            upper_bound = ["shl", 255, 1]
        else:
            upper_bound = -(2**255)
        return LLLnode.from_list(["uclamplt", in_arg, upper_bound],
                                 typ=BaseType("int256"),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == "decimal":
        return LLLnode.from_list(
            ["sdiv", in_arg, DECIMAL_DIVISOR],
            typ=BaseType("int256"),
            pos=getpos(expr),
        )

    elif isinstance(in_arg, LLLnode) and input_type == "bool":
        return LLLnode.from_list(in_arg,
                                 typ=BaseType("int256"),
                                 pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type in ("bytes32", "address"):
        return LLLnode(value=in_arg.value,
                       args=in_arg.args,
                       typ=BaseType("int256"),
                       pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type in ("Bytes", "String"):
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to int256",
                expr,
            )
        return byte_array_to_num(in_arg, expr, "int256")

    else:
        raise InvalidLiteral(f"Invalid input for int256: {in_arg}", expr)
Beispiel #15
0
def test_version_check_no_begin_or_end():
    """version_check() called with neither ``begin`` nor ``end`` raises CompilerPanic."""
    pytest.raises(CompilerPanic, opcodes.version_check)