def safe_div(x, y):
    """Emit IR for ``x / y`` that reverts on a zero divisor and on overflow.

    Only ``x.typ`` is inspected, so ``x`` and ``y`` are expected to share the
    same numeric type. For decimal types the dividend is first multiplied by
    the type's divisor so that the quotient keeps the fixed-point scale.

    Raises:
        UnimplementedException: for a decimal type whose bounds, once scaled
            by the divisor, would not fit in 256 bits.
    """
    num_info = x.typ._num_info
    typ = x.typ

    ok = [1]  # true

    if is_decimal_type(x.typ):
        lo, hi = num_info.bounds
        if max(abs(lo), abs(hi)) * num_info.divisor > 2**256 - 1:
            # stub to prevent us from adding fixed point numbers we don't know
            # how to deal with
            raise UnimplementedException(
                f"safe_div for decimal{num_info.bits}x{num_info.decimals}"
            )
        x = ["mul", x, num_info.divisor]

    DIV = "sdiv" if num_info.is_signed else "div"
    # the divisor is guarded by `clamp("gt", y, 0)`; with the unsigned `gt`
    # this presumably amounts to asserting `y != 0` (assumes clamp asserts
    # `op(arg, bound)` -- confirm against the helper)
    res = IRnode.from_list([DIV, x, clamp("gt", y, 0)], typ=typ)

    with res.cache_when_complex("res") as (b1, res):
        # TODO: refactor this condition / push some things into the optimizer
        if num_info.is_signed and num_info.bits == 256:
            # -2**255 / -1 does not fit in int256; rule that pair out
            if version_check(begin="constantinople"):
                upper_bound = ["shl", 255, 1]
            else:
                upper_bound = -(2**255)

            # use the IR-node literal flags, matching the folding rules below
            if not x.is_literal and not y.is_literal:
                ok = ["or", ["ne", y, ["not", 0]], ["ne", x, upper_bound]]
            # TODO push these rules into the optimizer
            elif x.is_literal and x.value == -(2**255):
                ok = ["ne", y, ["not", 0]]
            elif y.is_literal and y.value == -1:
                ok = ["ne", x, upper_bound]
            else:
                # x or y is a literal, and not an evil value.
                pass

        elif num_info.is_signed and is_integer_type(typ):
            lo, hi = num_info.bounds
            # we need to throw on min_value(typ) / -1,
            # but we can skip if one of the operands is a literal and not
            # the evil value
            can_skip_clamp = (x.is_literal and x.value != lo) or (
                y.is_literal and y.value != -1
            )
            if not can_skip_clamp:
                # clamp_basetype has fewer ops than the int256 rule.
                res = clamp_basetype(res)

        elif is_decimal_type(typ):
            # always clamp decimals, since decimal division can actually
            # result in something larger than either operand (e.g. 1.0 / 0.1)
            # TODO maybe use safe_mul
            res = clamp_basetype(res)

        check = IRnode.from_list(["assert", ok], error_msg="safediv")
        return IRnode.from_list(b1.resolve(["seq", check, res]))
def safe_sub(x, y):
    """Emit IR computing ``x - y`` that reverts on over/underflow.

    Types narrower than 256 bits are handled by clamping the raw difference
    to the type's bounds. For 256-bit types the EVM result has already
    wrapped mod 2**256, so an explicit comparison against ``x`` is asserted.
    """
    info = x.typ._num_info
    diff = IRnode.from_list(["sub", x, y], typ=x.typ.typ)

    # narrow types: any wrap-around lands outside the type bounds
    if info.bits < 256:
        return clamp_basetype(diff)

    # 256-bit types
    with diff.cache_when_complex("ans") as (builder, diff):
        if not info.is_signed:
            # unsigned subtraction can only shrink the minuend, so the result
            # must not exceed x (the unsigned form of the signed rule below,
            # since `y < 0` is always false)
            # TODO push down into optimizer rules.
            cond = ["le", diff, x]
        else:
            # subtracting a negative y must increase the value, and
            # subtracting a non-negative y must not:  (y < 0) == (diff > x)
            cond = ["eq", ["slt", y, 0], ["sgt", diff, x]]

        assertion = IRnode.from_list(["assert", cond], error_msg="safesub")
        return builder.resolve(IRnode.from_list(["seq", assertion, diff]))
def safe_add(x, y):
    """Emit IR computing ``x + y`` that reverts on overflow.

    Both operands must already carry the same ``BaseType``. Types narrower
    than 256 bits are handled by clamping the raw sum; for 256-bit types the
    EVM result has already wrapped mod 2**256, so an explicit comparison
    against ``x`` is asserted.
    """
    assert x.typ is not None and x.typ == y.typ and isinstance(x.typ, BaseType)
    num_info = x.typ._num_info

    res = IRnode.from_list(["add", x, y], typ=x.typ.typ)

    if num_info.bits < 256:
        return clamp_basetype(res)

    # bits == 256
    with res.cache_when_complex("ans") as (b1, res):
        if num_info.is_signed:
            # if r < 0:
            #   ans < l
            # else:
            #   ans >= l  # aka (iszero (ans < l))
            # aka: (r < 0) == (ans < l)
            ok = ["eq", ["slt", y, 0], ["slt", res, x]]
        else:
            # note this is "equivalent" to the unsigned form
            # of the above (because y < 0 == False)
            # ["eq", ["lt", y, 0], ["lt", res, x]]
            # TODO push down into optimizer rules.
            ok = ["ge", res, x]

        # tag the assertion like the sibling safe_sub / safe_mul / safe_div
        # helpers do, so a failing addition is identifiable in the IR
        check = IRnode.from_list(["assert", ok], error_msg="safeadd")
        ret = IRnode.from_list(["seq", check, res])
        return b1.resolve(ret)
def to_decimal(expr, arg, out_typ):
    """Lower a conversion of ``arg`` into the decimal type ``out_typ``.

    Constant expressions are folded by ``_literal_decimal``. Byte-string and
    ``bytesM`` inputs are reinterpreted as signed numbers and only clamped
    when the source is wider than 128 bits. Integers go through
    ``_int_to_fixed`` (plus ``_num_clamp`` when the integer type is wider
    than the decimal type); bools go through ``_int_to_fixed`` directly.

    Raises:
        CompilerPanic: for any other input type (unreachable after typecheck).
    """
    # question: is converting from Bytes to decimal allowed?
    _check_bytes(expr, arg, out_typ, max_bytes_allowed=16)

    if isinstance(expr, vy_ast.Constant):
        return _literal_decimal(expr, out_typ)

    src_typ = arg.typ

    if isinstance(src_typ, ByteArrayType) or is_bytes_m_type(src_typ):
        # widest value (in bits) the source could hold
        if isinstance(src_typ, ByteArrayType):
            src_bits = src_typ.maxlen * 8
        else:
            src_bits = src_typ._bytes_info.m_bits

        converted = _bytes_to_num(arg, out_typ, signed=True)
        # TODO revisit this condition once we have more decimal types
        # and decimal bounds expand
        # will be something like: if m_bits > 168
        if src_bits > 128:
            converted = clamp_basetype(IRnode.from_list(converted, typ=out_typ))
        return IRnode.from_list(converted, typ=out_typ)

    if is_integer_type(src_typ):
        int_info = src_typ._int_info
        converted = _int_to_fixed(arg, out_typ)
        dec_info = out_typ._decimal_info
        if int_info.bits > dec_info.bits:
            # TODO: _num_clamp probably not necessary bc already
            # clamped in _int_to_fixed
            converted = _num_clamp(converted, dec_info, int_info)
        return IRnode.from_list(converted, typ=out_typ)

    if is_base_type(src_typ, "bool"):
        return IRnode.from_list(_int_to_fixed(arg, out_typ), typ=out_typ)

    raise CompilerPanic("unreachable")  # pragma: notest
def to_decimal(expr, arg, out_typ):
    """Lower a conversion of ``arg`` into the decimal type ``out_typ``.

    Constant expressions are folded by ``_literal_decimal``. Byte-string and
    ``bytesM`` inputs (up to 32 bytes) are reinterpreted as signed numbers
    and clamped only when the source is wider than 168 bits. Integers are
    scaled by ``_int_to_fixed``; bools are multiplied by ``10**decimals``.

    Raises:
        CompilerPanic: for any other input type (unreachable after typecheck).
    """
    _check_bytes(expr, arg, out_typ, 32)
    dec_info = out_typ._decimal_info

    if isinstance(expr, vy_ast.Constant):
        return _literal_decimal(expr, arg.typ, out_typ)

    src_typ = arg.typ

    if isinstance(src_typ, ByteArrayType) or is_bytes_m_type(src_typ):
        # widest value (in bits) the source could hold
        if isinstance(src_typ, ByteArrayType):
            src_bits = src_typ.maxlen * 8
        else:
            src_bits = src_typ._bytes_info.m_bits

        converted = _bytes_to_num(arg, out_typ, signed=True)
        if src_bits > 168:
            converted = clamp_basetype(IRnode.from_list(converted, typ=out_typ))
        return IRnode.from_list(converted, typ=out_typ)

    if is_integer_type(src_typ):
        return IRnode.from_list(_int_to_fixed(arg, out_typ), typ=out_typ)

    if is_base_type(src_typ, "bool"):
        # TODO: consider adding _int_info to bool so we can use _int_to_fixed
        scaled = ["mul", arg, 10**dec_info.decimals]
        return IRnode.from_list(scaled, typ=out_typ)

    raise CompilerPanic("unreachable")  # pragma: notest
def to_uint8(expr, args, kwargs, context):
    """Lower ``convert(args[0], uint8)`` into LLL.

    Byte arrays (at most 32 bytes) are decoded with ``byte_array_to_num``;
    every other input is retyped as ``uint8`` and clamped.

    Raises:
        TypeMismatch: for a ``Bytes`` input whose max length exceeds 32.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "Bytes":
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to uint8",
                expr,
            )
        # uint8 clamp is already applied in byte_array_to_num, so clamping the
        # result again would only emit a redundant runtime check
        return LLLnode.from_list(
            byte_array_to_num(in_arg, "uint8"), typ=BaseType("uint8"), pos=getpos(expr)
        )

    # cast to output type so clamp_basetype works
    in_arg = LLLnode.from_list(in_arg, typ="uint8")
    return LLLnode.from_list(clamp_basetype(in_arg), typ=BaseType("uint8"), pos=getpos(expr))
def byte_array_to_num(arg, out_type):
    """Decode a byte array of at most 32 bytes into a clamped number.

    The first data word of the bytestring is loaded and shifted right by the
    number of unused bytes, then clamped to ``out_type`` via
    ``clamp_basetype``. The shift comes from the ``shr`` helper (presumably
    a logical shift -- confirm), so the value is not sign-extended.
    """
    cache_start = arg.is_complex_lll

    # refer to the start of the bytestring through a `with` variable when
    # `arg` is a complex expression, so it is only evaluated once
    if cache_start:
        bs_start = LLLnode.from_list(
            "bs_start", typ=arg.typ, location=arg.location, encoding=arg.encoding
        )
    else:
        bs_start = arg

    if arg.location != "storage":
        load = load_op(arg.location)
        len_ = LLLnode.from_list([load, bs_start], typ=BaseType("int256"))
        data = LLLnode.from_list([load, add_ofst(bs_start, 32)], typ=BaseType("int256"))
    else:
        len_ = get_bytearray_length(bs_start)
        data = LLLnode.from_list(["sload", add_ofst(bs_start, 1)], typ=BaseType("int256"))

    # bytestrings are right-padded with zeroes while integers are left-padded,
    # so shift right by the number of unused bytes (in bits), e.g.
    # "abcd000000000000" -> bitcast(000000000000abcd, output_type)
    pad_bits = ["mul", 8, ["sub", 32, "len_"]]
    as_number = LLLnode.from_list(shr(pad_bits, "val"), typ=out_type)
    clamped = clamp_basetype(as_number)

    # TODO use cache_when_complex for these `with` values
    body = ["with", "len_", len_, clamped]
    body = ["with", "val", data, body]
    if cache_start:
        body = ["with", "bs_start", arg, body]

    return LLLnode.from_list(
        body,
        typ=BaseType(out_type),
        annotation=f"__intrinsic__byte_array_to_num({out_type})",
    )
def safe_mul(x, y):
    """Emit IR for ``x * y`` that reverts on overflow.

    Precondition: ``x.typ.typ == y.typ.typ``. For types wider than 128 bits
    overflow mod 2**256 is detected by dividing the product back by ``y``.
    For decimals the product is then divided by the type's divisor to restore
    the fixed-point scale. The result is always passed through
    ``clamp_basetype`` to catch overflow mod <bits>.
    """
    # precondition: x.typ.typ == y.typ.typ
    num_info = x.typ._num_info

    # optimizer rules work better for the safemul checks below
    # if second operand is literal
    if x.is_literal:
        x, y = y, x

    res = IRnode.from_list(["mul", x, y], typ=x.typ.typ)

    DIV = "sdiv" if num_info.is_signed else "div"

    with res.cache_when_complex("ans") as (b1, res):
        ok = [1]  # True

        if num_info.bits > 128:  # check overflow mod 256
            # assert (res/y == x | y == 0)
            ok = ["or", ["eq", [DIV, res, y], x], ["iszero", y]]

        # int256
        if num_info.is_signed and num_info.bits == 256:
            # special case:
            # in the above sdiv check, if (r==-1 and l==-2**255),
            # -2**255<res> / -1<r> will return -2**255<l>.
            # need to check: not (r == -1 and l == -2**255)
            if version_check(begin="constantinople"):
                upper_bound = ["shl", 255, 1]
            else:
                upper_bound = -(2**255)

            check_x = ["ne", x, upper_bound]
            check_y = ["ne", ["not", y], 0]

            if not x.is_literal and not y.is_literal:
                # TODO can simplify this condition?
                ok = ["and", ok, ["or", check_x, check_y]]

            # TODO push some of this constant folding into optimizer
            elif x.is_literal and x.value == -(2**255):
                ok = ["and", ok, check_y]
            elif y.is_literal and y.value == -1:
                ok = ["and", ok, check_x]
            else:
                # x or y is a literal, and we have determined it is
                # not an evil value
                pass

        if is_decimal_type(res.typ):
            res = IRnode.from_list([DIV, res, num_info.divisor], typ=res.typ)

        # check overflow mod <bits>
        # NOTE: if 128 < bits < 256, `x * y` could be between
        # MAX_<bits> and 2**256 OR it could overflow past 2**256.
        # so, we check for overflow in mod 256 AS WELL AS mod <bits>
        # (if bits == 256, clamp_basetype is a no-op)
        res = clamp_basetype(res)

        # this is the multiplication overflow check (was mislabelled "safediv")
        check = IRnode.from_list(["assert", ok], error_msg="safemul")
        res = IRnode.from_list(["seq", check, res], typ=res.typ)

        return b1.resolve(res)
def parse_BinOp(self):
    """Lower a numeric binary operation (``+ - * / % **``) into LLL.

    Returns ``None`` when either operand is non-numeric, when the operator or
    operand type is not handled below, when the right operand of ``/`` or
    ``%`` is the literal ``0``, or for ``**`` where neither side is an
    integer literal.

    Except for ``**`` (which returns early), the arithmetic is expressed over
    the LLL variables ``l`` and ``r``; these are bound to the operands at the
    end and the result is passed through ``clamp_basetype``.
    """
    left = Expr.parse_value_expr(self.expr.left, self.context)
    right = Expr.parse_value_expr(self.expr.right, self.context)

    if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
        return

    pos = getpos(self.expr)
    types = {left.typ.typ, right.typ.typ}
    literals = {left.typ.is_literal, right.typ.is_literal}

    # If one value of the operation is a literal, we recast it to match the non-literal type.
    # We know this is OK because types were already verified in the actual typechecking pass.
    # This is a temporary solution to not break codegen while we work toward removing types
    # altogether at this stage of compilation. @iamdefinitelyahuman
    if literals == {True, False} and len(types) > 1 and "decimal" not in types:
        if left.typ.is_literal and SizeLimits.in_bounds(right.typ.typ, left.value):
            left = LLLnode.from_list(
                left.value, typ=BaseType(right.typ.typ, is_literal=True), pos=pos,
            )
        elif right.typ.is_literal and SizeLimits.in_bounds(left.typ.typ, right.value):
            right = LLLnode.from_list(
                right.value, typ=BaseType(left.typ.typ, is_literal=True), pos=pos,
            )

    ltyp, rtyp = left.typ.typ, right.typ.typ

    # Sanity check - ensure that we aren't dealing with different types
    # This should be unreachable due to the type check pass
    assert ltyp == rtyp, "unreachable"

    arith = None
    if isinstance(self.expr.op, (vy_ast.Add, vy_ast.Sub)):
        new_typ = BaseType(ltyp)
        if ltyp == "uint256":
            if isinstance(self.expr.op, vy_ast.Add):
                # safeadd: an unsigned sum that wrapped is smaller than l
                arith = [
                    "seq",
                    ["assert", ["ge", ["add", "l", "r"], "l"]],
                    ["add", "l", "r"]
                ]
            elif isinstance(self.expr.op, vy_ast.Sub):
                # safesub: unsigned l - r underflows exactly when r > l
                arith = [
                    "seq",
                    ["assert", ["ge", "l", "r"]],
                    ["sub", "l", "r"]
                ]
        elif ltyp == "int256":
            # comp1 is the result-vs-l relation that must hold when r >= 0,
            # comp2 the one that must hold when r < 0
            if isinstance(self.expr.op, vy_ast.Add):
                op, comp1, comp2 = "add", "sge", "slt"
            else:
                op, comp1, comp2 = "sub", "sle", "sgt"
            if right.typ.is_literal:
                # the sign of r is known at compile time: emit one comparison
                if right.value >= 0:
                    arith = [
                        "seq",
                        ["assert", [comp1, [op, "l", "r"], "l"]],
                        [op, "l", "r"]
                    ]
                else:
                    arith = [
                        "seq",
                        ["assert", [comp2, [op, "l", "r"], "l"]],
                        [op, "l", "r"]
                    ]
            else:
                # sign of r unknown: select the comparison at runtime
                arith = [
                    "with",
                    "ans",
                    [op, "l", "r"],
                    [
                        "seq",
                        [
                            "assert",
                            [
                                "or",
                                ["and", ["sge", "r", 0], [comp1, "ans", "l"]],
                                ["and", ["slt", "r", 0], [comp2, "ans", "l"]],
                            ],
                        ],
                        "ans",
                    ],
                ]
        elif ltyp in ("decimal", "int128", "uint8"):
            # narrow types: overflow is caught by clamp_basetype at the end
            op = "add" if isinstance(self.expr.op, vy_ast.Add) else "sub"
            arith = [op, "l", "r"]

    elif isinstance(self.expr.op, vy_ast.Mult):
        new_typ = BaseType(ltyp)
        if ltyp == "uint256":
            # overflow check: ans / l == r unless l == 0
            arith = [
                "with",
                "ans",
                ["mul", "l", "r"],
                [
                    "seq",
                    [
                        "assert",
                        [
                            "or",
                            ["eq", ["div", "ans", "l"], "r"],
                            ["iszero", "l"]
                        ]
                    ],
                    "ans",
                ],
            ]
        elif ltyp == "int256":
            # -2**255 as a 256-bit word: a single `shl` where available
            if version_check(begin="constantinople"):
                upper_bound = ["shl", 255, 1]
            else:
                upper_bound = -(2**255)
            # for l == -1, r == -2**255 the product wraps to -2**255 and
            # sdiv(ans, l) == r, so the division check below cannot detect it;
            # bounds_check excludes that pair (skipped when a literal operand
            # already rules it out)
            if not left.typ.is_literal and not right.typ.is_literal:
                bounds_check = [
                    "assert",
                    [
                        "or",
                        ["ne", "l", ["not", 0]],
                        ["ne", "r", upper_bound]
                    ],
                ]
            elif left.typ.is_literal and left.value == -1:
                bounds_check = ["assert", ["ne", "r", upper_bound]]
            elif right.typ.is_literal and right.value == -(2**255):
                bounds_check = ["assert", ["ne", "l", ["not", 0]]]
            else:
                bounds_check = "pass"
            arith = [
                "with",
                "ans",
                ["mul", "l", "r"],
                [
                    "seq",
                    bounds_check,
                    [
                        "assert",
                        [
                            "or",
                            ["eq", ["sdiv", "ans", "l"], "r"],
                            ["iszero", "l"]
                        ]
                    ],
                    "ans",
                ],
            ]
        elif ltyp in ("int128", "uint8"):
            # narrow types: overflow is caught by clamp_basetype at the end
            arith = ["mul", "l", "r"]
        elif ltyp == "decimal":
            # check the raw product did not wrap, then rescale it
            arith = [
                "with",
                "ans",
                ["mul", "l", "r"],
                [
                    "seq",
                    [
                        "assert",
                        [
                            "or",
                            ["eq", ["sdiv", "ans", "l"], "r"],
                            ["iszero", "l"]
                        ]
                    ],
                    ["sdiv", "ans", DECIMAL_DIVISOR],
                ],
            ]

    elif isinstance(self.expr.op, vy_ast.Div):
        # a literal zero divisor is not lowered here
        if right.typ.is_literal and right.value == 0:
            return
        new_typ = BaseType(ltyp)
        if right.typ.is_literal:
            divisor = "r"
        else:
            # only apply the non-zero clamp when r is not a constant
            divisor = ["clamp_nonzero", "r"]
        if ltyp in ("uint8", "uint256"):
            arith = ["div", "l", divisor]
        elif ltyp == "int256":
            if version_check(begin="constantinople"):
                upper_bound = ["shl", 255, 1]
            else:
                upper_bound = -(2**255)
            # -2**255 / -1 does not fit in int256; exclude that pair unless a
            # literal operand already rules it out
            if not left.typ.is_literal and not right.typ.is_literal:
                bounds_check = [
                    "assert",
                    [
                        "or",
                        ["ne", "r", ["not", 0]],
                        ["ne", "l", upper_bound]
                    ],
                ]
            elif left.typ.is_literal and left.value == -(2**255):
                bounds_check = ["assert", ["ne", "r", ["not", 0]]]
            elif right.typ.is_literal and right.value == -1:
                bounds_check = ["assert", ["ne", "l", upper_bound]]
            else:
                bounds_check = "pass"
            arith = ["seq", bounds_check, ["sdiv", "l", divisor]]
        elif ltyp == "int128":
            arith = ["sdiv", "l", divisor]
        elif ltyp == "decimal":
            # scale the dividend so the quotient keeps the fixed-point scale
            arith = [
                "sdiv",
                ["mul", "l", DECIMAL_DIVISOR],
                divisor,
            ]

    elif isinstance(self.expr.op, vy_ast.Mod):
        # a literal zero divisor is not lowered here
        if right.typ.is_literal and right.value == 0:
            return
        new_typ = BaseType(ltyp)
        if right.typ.is_literal:
            divisor = "r"
        else:
            # only apply the non-zero clamp when r is not a constant
            divisor = ["clamp_nonzero", "r"]
        if ltyp in ("uint8", "uint256"):
            arith = ["mod", "l", divisor]
        else:
            arith = ["smod", "l", divisor]

    elif isinstance(self.expr.op, vy_ast.Pow):
        new_typ = BaseType(ltyp)
        # constant-fold trivial bases: 1 ** n == 1, 0 ** n == (n == 0)
        if self.expr.left.get("value") == 1:
            return LLLnode.from_list([1], typ=new_typ, pos=pos)
        if self.expr.left.get("value") == 0:
            return LLLnode.from_list(["iszero", right], typ=new_typ, pos=pos)
        if ltyp == "int128":
            is_signed = True
            num_bits = 128
        elif ltyp == "int256":
            is_signed = True
            num_bits = 256
        elif ltyp == "uint8":
            is_signed = False
            num_bits = 8
        else:
            is_signed = False
            num_bits = 256
        if isinstance(self.expr.left, vy_ast.Int):
            # literal base: bound the exponent
            value = self.expr.left.value
            upper_bound = calculate_largest_power(value, num_bits, is_signed) + 1
            # for signed integers, this also prevents negative values
            clamp = ["lt", right, upper_bound]
            return LLLnode.from_list(
                ["seq", ["assert", clamp], ["exp", left, right]],
                typ=new_typ,
                pos=pos,
            )
        elif isinstance(self.expr.right, vy_ast.Int):
            # literal exponent: bound the base (symmetrically when signed)
            value = self.expr.right.value
            upper_bound = calculate_largest_base(value, num_bits, is_signed) + 1
            if is_signed:
                clamp = [
                    "and",
                    ["slt", left, upper_bound],
                    ["sgt", left, -upper_bound]
                ]
            else:
                clamp = ["lt", left, upper_bound]
            return LLLnode.from_list(
                ["seq", ["assert", clamp], ["exp", left, right]],
                typ=new_typ,
                pos=pos,
            )
        else:
            # `a ** b` where neither `a` or `b` are known
            # TODO this is currently unreachable, once we implement a way to do it safely
            # remove the check in `vyper/context/types/value/numeric.py`
            return

    # operator / type combination not handled above
    if arith is None:
        return

    arith = LLLnode.from_list(arith, typ=new_typ)

    # bind the operands to `l` and `r` (each evaluated once, left first)
    p = [
        "with",
        "l",
        left,
        [
            "with",
            "r",
            right,
            # note clamp_basetype is a noop on [u]int256
            # note: clamp_basetype throws on unclampable input
            clamp_basetype(arith),
        ],
    ]
    return LLLnode.from_list(p, typ=new_typ, pos=pos)
def to_address(expr, args, kwargs, context):
    """Lower ``convert(args[0], address)``: retype the argument as an
    ``address`` and clamp it to that type's bounds."""
    # retype first -- clamp_basetype dispatches on the node's type
    as_address = LLLnode.from_list(args[0], typ="address")
    clamped = clamp_basetype(as_address)
    return LLLnode.from_list(clamped, typ=BaseType("address"), pos=getpos(expr))
def to_decimal(expr, args, kwargs, context):
    """Lower ``convert(args[0], decimal)`` into LLL.

    A decimal is represented as its value multiplied by ``DECIMAL_DIVISOR``.
    Non-literal ``uint256`` / ``bytes32`` inputs are clamped to the int128
    range *before* being scaled, exactly like the ``int256`` branch. Scaling
    first lets the 256-bit ``mul`` wrap for large inputs (e.g.
    ``ceil(2**256 / DECIMAL_DIVISOR)``), and the wrapped product can land
    back inside the decimal bounds and slip past a post-multiplication clamp.

    NOTE(review): this relies on the decimal bounds being the int128 bounds
    scaled by ``DECIMAL_DIVISOR``. The ``int256`` branch below already
    assumes the same; confirm against ``MemoryPositions.MAXDECIMAL``.

    Raises:
        TypeMismatch: for a ``Bytes`` input whose max length exceeds 32.
        InvalidLiteral: for an out-of-range literal or an unsupported input.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "Bytes":
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to decimal",
                expr,
            )
        # use byte_array_to_num(int128) because it is cheaper to clamp int128
        num = byte_array_to_num(in_arg, "int128")
        return LLLnode.from_list(
            ["mul", num, DECIMAL_DIVISOR], typ=BaseType("decimal"), pos=getpos(expr)
        )

    elif input_type == "uint256":
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", (in_arg.value * DECIMAL_DIVISOR)):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}", expr)
            return LLLnode.from_list(
                ["mul", in_arg, DECIMAL_DIVISOR], typ=BaseType("decimal"), pos=getpos(expr)
            )
        # clamp BEFORE scaling: check 0 <= input <= MAX_INT128 (same clamp as
        # to_int128 uses for uint256), so the mul below cannot wrap
        return LLLnode.from_list(
            ["mul", int_clamp(in_arg, 127, signed=False), DECIMAL_DIVISOR],
            typ=BaseType("decimal"),
            pos=getpos(expr),
        )

    elif input_type == "address":
        return LLLnode.from_list(
            [
                "mul",
                ["signextend", 15, ["and", in_arg, (SizeLimits.ADDRSIZE - 1)]],
                DECIMAL_DIVISOR,
            ],
            typ=BaseType("decimal"),
            pos=getpos(expr),
        )

    elif input_type == "bytes32":
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", (in_arg.value * DECIMAL_DIVISOR)):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}", expr)
            return LLLnode.from_list(
                ["mul", in_arg, DECIMAL_DIVISOR], typ=BaseType("decimal"), pos=getpos(expr)
            )
        # clamp BEFORE scaling so the signed mul below cannot wrap; cast to
        # int128 so clamp_basetype works (same as the int256 branch)
        in_arg = LLLnode.from_list(in_arg, typ="int128")
        return LLLnode.from_list(
            ["mul", clamp_basetype(in_arg), DECIMAL_DIVISOR],
            typ=BaseType("decimal"),
            pos=getpos(expr),
        )

    elif input_type == "int256":
        # cast in_arg so clamp_basetype works
        in_arg = LLLnode.from_list(in_arg, typ="int128")
        return LLLnode.from_list(
            ["mul", clamp_basetype(in_arg), DECIMAL_DIVISOR],
            typ=BaseType("decimal"),
            pos=getpos(expr),
        )

    elif input_type in ("uint8", "int128", "bool"):
        return LLLnode.from_list(
            ["mul", in_arg, DECIMAL_DIVISOR], typ=BaseType("decimal"), pos=getpos(expr)
        )

    else:
        raise InvalidLiteral(f"Invalid input for decimal: {in_arg}", expr)
def to_int128(expr, args, kwargs, context):
    """Lower ``convert(args[0], int128)`` into LLL.

    Literals are range-checked at compile time. Non-literal wide inputs are
    clamped at runtime, decimals are divided by ``DECIMAL_DIVISOR``, and
    ``bool`` / ``uint8`` pass through unchanged.

    Raises:
        InvalidLiteral: for an out-of-range literal or an unsupported input.
            The offending ``expr`` is always attached for source location.
        TypeMismatch: for a ``String`` / ``Bytes`` input longer than 32.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "num_literal":
        if isinstance(in_arg, int):
            if not SizeLimits.in_bounds("int128", in_arg):
                raise InvalidLiteral(f"Number out of range: {in_arg}", expr)
            return LLLnode.from_list(in_arg, typ=BaseType("int128"), pos=getpos(expr))
        elif isinstance(in_arg, Decimal):
            truncated = math.trunc(in_arg)
            if not SizeLimits.in_bounds("int128", truncated):
                raise InvalidLiteral(f"Number out of range: {truncated}", expr)
            return LLLnode.from_list(truncated, typ=BaseType("int128"), pos=getpos(expr))
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}", expr)

    elif input_type in ("bytes32", "int256"):
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}", expr)
            return LLLnode.from_list(in_arg, typ=BaseType("int128"), pos=getpos(expr))
        # cast to output type so clamp_basetype works
        in_arg = LLLnode.from_list(in_arg, typ="int128")
        return LLLnode.from_list(
            clamp_basetype(in_arg),
            typ=BaseType("int128"),
            pos=getpos(expr),
        )

    # CMC 20211020: what is the purpose of this .. it lops off 32 bits
    elif input_type == "address":
        return LLLnode.from_list(
            ["signextend", 15, ["and", in_arg, (SizeLimits.ADDRSIZE - 1)]],
            typ=BaseType("int128"),
            pos=getpos(expr),
        )

    elif input_type in ("String", "Bytes"):
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to int128",
                expr,
            )
        # byte_array_to_num already clamps its result to int128
        return byte_array_to_num(in_arg, "int128")

    elif input_type == "uint256":
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}", expr)
            return LLLnode.from_list(in_arg, typ=BaseType("int128"), pos=getpos(expr))

        # !! do not use clamp_basetype. check that 0 <= input <= MAX_INT128.
        res = int_clamp(in_arg, 127, signed=False)
        return LLLnode.from_list(
            res,
            typ="int128",
            pos=getpos(expr),
        )

    elif input_type == "decimal":
        # cast to int128 so clamp_basetype works
        res = LLLnode.from_list(["sdiv", in_arg, DECIMAL_DIVISOR], typ="int128")
        return LLLnode.from_list(clamp_basetype(res), typ="int128", pos=getpos(expr))

    elif input_type in ("bool", "uint8"):
        # note: for int8, would need signextend
        return LLLnode.from_list(in_arg, typ=BaseType("int128"), pos=getpos(expr))

    else:
        raise InvalidLiteral(f"Invalid input for int128: {in_arg}", expr)