def test_sha3_32():
    """`sha3_32` of a constant compiles to identical assembly with and without the optimizer."""
    source = ["sha3_32", 0]
    expected_asm = [
        "PUSH1", 0, "PUSH1", 192, "MSTORE",
        "PUSH1", 32, "PUSH1", 192, "SHA3",
    ]
    # build a fresh node for each pass: first unoptimized, then optimized
    for prepare in (lambda node: node, optimizer.optimize):
        assert compile_lll.compile_to_assembly(prepare(LLLnode.from_list(source))) == expected_asm
def to_address(expr, args, kwargs, context):
    """Convert ``args[0]`` to an ``address`` node, validated with ``address_clamp``.

    The input is bound once to ``_in_arg`` so that the clamp and the
    returned value both read the same evaluated value.
    """
    return LLLnode.from_list(
        ["with", "_in_arg", args[0], ["seq", address_clamp("_in_arg"), "_in_arg"]],
        typ=BaseType("address"),
        pos=getpos(expr),
    )
def to_bool(expr, args, kwargs, context):
    """Convert ``args[0]`` to a ``bool`` node: zero maps to 0, anything else to 1.

    Bytes inputs (max length <= 32) are first read as a uint256 number.

    Raises:
        TypeMismatch: if a Bytes input has a max length above 32.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    operand = in_arg
    if input_type == "Bytes":
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to bool",
                expr,
            )
        operand = byte_array_to_num(in_arg, expr, "uint256")

    # iszero(iszero(x)) is 0 when x == 0 and 1 otherwise
    return LLLnode.from_list(
        ["iszero", ["iszero", operand]],
        typ=BaseType("bool"),
        pos=getpos(expr),
    )
def lll_compiler(lll, *args, **kwargs):
    """Optimize and compile raw LLL, deploy it via ``w3``, and return a contract handle.

    ``kwargs["abi"]`` (default: empty list) is used both for deployment and
    for the returned ``VyperContract`` wrapper. Extra positional args are ignored.
    """
    optimized = optimizer.optimize(LLLnode.from_list(lll))
    assembly = compile_lll.compile_to_assembly(optimized)
    bytecode, _ = compile_lll.assembly_to_evm(assembly)
    abi = kwargs.get("abi") or []

    factory = w3.eth.contract(abi=abi, bytecode=bytecode)
    tx_hash = factory.constructor().transact()
    receipt = w3.eth.getTransactionReceipt(tx_hash)

    return w3.eth.contract(
        receipt["contractAddress"],
        abi=abi,
        bytecode=bytecode,
        ContractFactoryClass=VyperContract,
    )
def to_bytes32(expr, args, kwargs, context):
    """Convert ``args[0]`` to a ``bytes32`` node.

    Non-Bytes inputs are retyped as ``bytes32`` without changing their value.
    For a Bytes input (max length <= 32), one word is loaded: from memory at
    ``pointer + 32``, or from storage at ``slot + 1`` (presumably the word
    following the length field).

    Raises:
        TypeMismatch: if a Bytes input has a max length above 32, or lives in a
            location other than memory or storage (previously this case
            silently returned ``None``).
    """
    in_arg = args[0]
    input_type, _len = get_type(in_arg)

    if input_type != "Bytes":
        return LLLnode(
            value=in_arg.value,
            args=in_arg.args,
            typ=BaseType("bytes32"),
            pos=getpos(expr),
        )

    if _len > 32:
        raise TypeMismatch(
            f"Unable to convert bytes[{_len}] to bytes32, max length is too " "large.",
            expr,
        )

    if in_arg.location == "memory":
        return LLLnode.from_list(
            ["mload", ["add", in_arg, 32]], typ=BaseType("bytes32"), pos=getpos(expr)
        )
    if in_arg.location == "storage":
        return LLLnode.from_list(
            ["sload", ["add", in_arg, 1]], typ=BaseType("bytes32"), pos=getpos(expr)
        )

    # Any other location used to fall through and return None, which only
    # failed later with an unrelated error; report it here instead.
    raise TypeMismatch(
        f"Cannot convert bytes located in {in_arg.location} to bytes32",
        expr,
    )
def _merge_memzero(argz):
    """Collapse runs of adjacent memory-zero'ing ops in a ``seq`` body into one op.

    Two op shapes count as zero'ing, provided their offsets (and lengths) are
    integer literals:

    * ``mstore <offset> 0``, which zeroes 32 bytes;
    * ``calldatacopy <offset> calldatasize <length>``, which copies from the
      end of calldata (efficient zero'ing via ``empty()``).

    Consecutive zero'ing ops whose memory ranges are contiguous form a run.
    Each run of two or more ops is replaced in place by a single
    ``calldatacopy <start> calldatasize <total_length>``. ``pass`` nodes are
    skipped while scanning.

    Behavior changes from the previous version:

    * A run at the very end of the ``seq`` is now merged as well.
    * A zero'ing op that breaks contiguity now starts a new run instead of
      being dropped from consideration.

    ``argz`` is mutated in place; nothing is returned.
    """
    run: List = []
    run_start = 0
    run_length = 0

    def _flush_run():
        # Replace a run of 2+ zero'ing nodes with one calldatacopy; a single
        # node is left as-is. Always empties the pending run.
        if len(run) > 1:
            merged = LLLnode.from_list(
                ["calldatacopy", run_start, "calldatasize", run_length],
                pos=run[0].pos,
            )
            # replace first zero'ing operation with the merged node, remove the rest
            argz[argz.index(run[0])] = merged
            for stale in run[1:]:
                argz.remove(stale)
        run.clear()

    for lll_node in [i for i in argz if i.value != "pass"]:
        if (
            lll_node.value == "mstore"
            and isinstance(lll_node.args[0].value, int)
            and lll_node.args[1].value == 0
        ):
            # mstore of a zero value: one 32-byte word
            offset, length = lll_node.args[0].value, 32
        elif (
            lll_node.value == "calldatacopy"
            and isinstance(lll_node.args[0].value, int)
            and lll_node.args[1].value == "calldatasize"
            and isinstance(lll_node.args[2].value, int)
        ):
            # calldatacopy from the end of calldata - efficient zero'ing via `empty()`
            offset, length = lll_node.args[0].value, lll_node.args[2].value
        else:
            # not a zero'ing operation: close out any pending run
            _flush_run()
            continue

        if run and run_start + run_length != offset:
            # zero'ing op that does not extend the current run: close the
            # current run and let this node start a new one
            _flush_run()
        if not run:
            run_start, run_length = offset, 0
        run.append(lll_node)
        run_length += length

    # the seq may end with a run of zero'ing operations
    _flush_run()
def to_uint256(expr, args, kwargs, context):
    """Convert ``args[0]`` to a ``uint256`` node.

    Numeric literals are range-checked at compile time; Decimal literals are
    truncated toward zero first. Other conversions:

    * int128/int256: clamped to be >= 0 at runtime.
    * decimal: clamped to be >= 0, then divided by ``DECIMAL_DIVISOR``.
    * bool, bytes32, address: reinterpreted as uint256 without changing the value.
    * Bytes (max length <= 32): read via ``byte_array_to_num``.

    Raises:
        InvalidLiteral: for out-of-range literals, unknown literal types,
            oversized Bytes inputs, or unsupported input types.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "num_literal":
        if isinstance(in_arg, int):
            if not SizeLimits.in_bounds("uint256", in_arg):
                raise InvalidLiteral(f"Number out of range: {in_arg}", expr)
            return LLLnode.from_list(in_arg, typ=BaseType("uint256"), pos=getpos(expr))
        elif isinstance(in_arg, Decimal):
            truncated = math.trunc(in_arg)
            if not SizeLimits.in_bounds("uint256", truncated):
                raise InvalidLiteral(f"Number out of range: {truncated}", expr)
            return LLLnode.from_list(truncated, typ=BaseType("uint256"), pos=getpos(expr))
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}", expr)

    elif isinstance(in_arg, LLLnode) and input_type in ("int128", "int256"):
        return LLLnode.from_list(
            ["clampge", in_arg, 0], typ=BaseType("uint256"), pos=getpos(expr)
        )

    elif isinstance(in_arg, LLLnode) and input_type == "decimal":
        return LLLnode.from_list(
            ["div", ["clampge", in_arg, 0], DECIMAL_DIVISOR],
            typ=BaseType("uint256"),
            pos=getpos(expr),
        )

    elif isinstance(in_arg, LLLnode) and input_type == "bool":
        return LLLnode.from_list(in_arg, typ=BaseType("uint256"), pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type in ("bytes32", "address"):
        return LLLnode(
            value=in_arg.value, args=in_arg.args, typ=BaseType("uint256"), pos=getpos(expr)
        )

    elif isinstance(in_arg, LLLnode) and input_type == "Bytes":
        if in_arg.typ.maxlen > 32:
            # NOTE(review): the sibling converters raise TypeMismatch for this
            # case; kept as InvalidLiteral so callers' expectations don't change.
            raise InvalidLiteral(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to uint256",
                expr,
            )
        return byte_array_to_num(in_arg, expr, "uint256")

    else:
        raise InvalidLiteral(f"Invalid input for uint256: {in_arg}", expr)
def _merge_calldataload(argz):
    """Collapse runs of word-by-word calldata->memory copies in a ``seq`` body.

    A word copy is ``mstore <mem_offset> (calldataload <calldata_offset>)``
    with integer-literal offsets. Consecutive word copies whose memory offsets
    and calldata offsets both advance by exactly 32 bytes form a run. Each run
    of two or more copies is replaced in place by a single
    ``calldatacopy <mem_start> <calldata_start> <total_length>``. ``pass``
    nodes are skipped while scanning.

    Behavior changes from the previous version:

    * A run at the very end of the ``seq`` is now merged as well.
    * A word copy that breaks contiguity now starts a new run.

    ``argz`` is mutated in place; nothing is returned.
    """
    run: List = []
    mem_start = 0
    calldata_start = 0
    run_length = 0

    def _flush_run():
        # Replace a run of 2+ word copies with one calldatacopy; a single
        # node is left as-is. Always empties the pending run.
        if len(run) > 1:
            merged = LLLnode.from_list(
                ["calldatacopy", mem_start, calldata_start, run_length],
                pos=run[0].pos,
            )
            # replace first copy operation with the merged node, remove the rest
            argz[argz.index(run[0])] = merged
            for stale in run[1:]:
                argz.remove(stale)
        run.clear()

    for lll_node in [i for i in argz if i.value != "pass"]:
        is_word_copy = (
            lll_node.value == "mstore"
            and isinstance(lll_node.args[0].value, int)
            and lll_node.args[1].value == "calldataload"
            and isinstance(lll_node.args[1].args[0].value, int)
        )
        if not is_word_copy:
            # a different operation: close out any pending run
            _flush_run()
            continue

        # mstore of a calldataload from a literal calldata offset
        mem_offset = lll_node.args[0].value
        calldata_offset = lll_node.args[1].args[0].value
        if run and (
            mem_start + run_length != mem_offset
            or calldata_start + run_length != calldata_offset
        ):
            # not contiguous with the current run: close it, start a new one
            _flush_run()
        if not run:
            mem_start, calldata_start, run_length = mem_offset, calldata_offset, 0
        run.append(lll_node)
        run_length += 32

    # the seq may end with a run of word copies
    _flush_run()
def _to_bytelike(expr, args, kwargs, context, bytetype):
    """Retype a bytes-like ``args[0]`` as ``String`` or ``Bytes`` without copying.

    ``args[1]`` is the target type annotation; its ``slice.value.n`` is read
    as the target max length. The result keeps the input's value, args,
    location and max length.

    Raises:
        TypeMismatch: for an unknown ``bytetype``, or when the input's max
            length exceeds the target max length.
    """
    if bytetype == "String":
        result_type_cls = StringType
    elif bytetype == "Bytes":
        result_type_cls = ByteArrayType
    else:
        raise TypeMismatch(f"Invalid {bytetype} supplied")

    source = args[0]
    source_maxlen = source.typ.maxlen
    target_maxlen = args[1].slice.value.n
    if source_maxlen > target_maxlen:
        raise TypeMismatch(
            f"Cannot convert as input {bytetype} are larger than max length",
            expr,
        )

    return LLLnode(
        value=source.value,
        args=source.args,
        typ=result_type_cls(source.typ.maxlen),
        pos=getpos(expr),
        location=source.location,
    )
def compile_to_lll(input_file, output_formats, show_gas_estimates=False):
    """Compile the first s-expression in ``input_file`` and collect the requested outputs.

    Recognized entries in ``output_formats``: "ir" (the parsed LLL), "opt_ir"
    (the optimizer's output), "asm", and "bytecode" (hex string with a "0x"
    prefix). Assembly is always generated from the unoptimized LLL.
    """
    with open(input_file) as source_fh:
        s_expressions = parse_s_exp(source_fh.read())

    if show_gas_estimates:
        LLLnode.repr_show_gas = True

    outputs = {}
    root = LLLnode.from_list(s_expressions[0])

    if "ir" in output_formats:
        outputs["ir"] = root
    if "opt_ir" in output_formats:
        outputs["opt_ir"] = optimizer.optimize(root)

    assembly = compile_lll.compile_to_assembly(root)
    if "asm" in output_formats:
        outputs["asm"] = assembly

    if "bytecode" in output_formats:
        bytecode, _srcmap = compile_lll.assembly_to_evm(assembly)
        outputs["bytecode"] = "0x" + bytecode.hex()

    return outputs
def apply_general_optimizations(node: LLLnode) -> LLLnode:
    """Recursively apply peephole rewrites to ``node`` and return a new node.

    Children are optimized first (bottom-up). For ``seq`` nodes, the children
    are then passed through ``_merge_memzero`` and ``_merge_calldataload``.
    Afterwards the first matching rule below is applied:

    * ``arith`` op on two integer children: fold to a constant.
    * ``_is_constant_add(node, argz)``: fold the two constants and emit
      ``add <folded> argz[1].args[1]``.
    * ``add`` with a literal 0 on either side: return the other operand.
    * ``clamp`` with all-constant args: check statically and return the value.
    * ``clamp`` with constant lower bound and value: reduce to ``clample``.
    * ``clamp_nonzero`` of a constant: check statically.
    * ``eq x 0``: rewrite to ``iszero x``.
    * ``if``/``if_unchecked``/``assert`` whose first child is ``ne``: rewrite
      the ``ne`` to ``xor``.
    * ``seq``: splice nested ``seq`` children into the parent.
    * Otherwise: rebuild the node with optimized children, adjusting
      ``total_gas`` when it is set.

    Raises:
        Exception: "Clamp always fails" when a clamp is statically known to fail.
    """
    # TODO refactor this into several functions
    argz = [apply_general_optimizations(arg) for arg in node.args]

    # merge runs of zero'ing / calldata-copying ops (mutates argz in place)
    if node.value == "seq":
        _merge_memzero(argz)
        _merge_calldataload(argz)

    # constant folding: both operands are integer literals
    if node.value in arith and int_at(argz, 0) and int_at(argz, 1):
        left, right = get_int_at(argz, 0), get_int_at(argz, 1)
        # `node.value in arith` implies that `node.value` is a `str`
        calcer, symb = arith[str(node.value)]
        new_value = calcer(left, right)
        # keep a human-readable annotation of the folded expression
        if argz[0].annotation and argz[1].annotation:
            annotation = argz[0].annotation + symb + argz[1].annotation
        elif argz[0].annotation or argz[1].annotation:
            annotation = ((argz[0].annotation or str(left)) + symb +
                          (argz[1].annotation or str(right)))
        else:
            annotation = ""
        return LLLnode(
            new_value,
            [],
            node.typ,
            None,
            node.pos,
            annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    elif _is_constant_add(node, argz):
        # `node.value in arith` implies that `node.value` is a `str`
        # NOTE(review): the guard here is `_is_constant_add`, not
        # `node.value in arith`; presumably that helper only accepts arith
        # ops — confirm.
        calcer, symb = arith[str(node.value)]
        if argz[0].annotation and argz[1].args[0].annotation:
            annotation = argz[0].annotation + symb + argz[1].args[0].annotation
        elif argz[0].annotation or argz[1].args[0].annotation:
            annotation = (
                (argz[0].annotation or str(argz[0].value)) + symb +
                (argz[1].args[0].annotation or str(argz[1].args[0].value)))
        else:
            annotation = ""
        # NOTE(review): `calcer` is unused (the fold is always `+`), and this
        # node is built with location None and no `pos`, unlike the other
        # rewrites, so source positions are lost here.
        return LLLnode(
            "add",
            [
                LLLnode(int(argz[0].value) + int(argz[1].args[0].value), annotation=annotation),
                argz[1].args[1],
            ],
            node.typ,
            None,
            annotation=node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # [add, 0, x] -> x
    elif node.value == "add" and get_int_at(argz, 0) == 0:
        return LLLnode(
            argz[1].value,
            argz[1].args,
            node.typ,
            node.location,
            node.pos,
            annotation=argz[1].annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # [add, x, 0] -> x
    elif node.value == "add" and get_int_at(argz, 1) == 0:
        return LLLnode(
            argz[0].value,
            argz[0].args,
            node.typ,
            node.location,
            node.pos,
            argz[0].annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # [clamp, lo, x, hi] with all constants: check at compile time (signed compare)
    elif node.value == "clamp" and int_at(argz, 0) and int_at(
            argz, 1) and int_at(argz, 2):
        if get_int_at(argz, 0, True) > get_int_at(argz, 1, True):  # type: ignore
            raise Exception("Clamp always fails")
        elif get_int_at(argz, 1, True) > get_int_at(argz, 2, True):  # type: ignore
            raise Exception("Clamp always fails")
        else:
            return argz[1]
    # [clamp, lo, x, hi] with constant lo and x: lower bound is already known
    # to hold, so only the upper bound check (clample) remains
    elif node.value == "clamp" and int_at(argz, 0) and int_at(argz, 1):
        if get_int_at(argz, 0, True) > get_int_at(argz, 1, True):  # type: ignore
            raise Exception("Clamp always fails")
        else:
            return LLLnode(
                "clample",
                [argz[1], argz[2]],
                node.typ,
                node.location,
                node.pos,
                node.annotation,
                add_gas_estimate=node.add_gas_estimate,
                valency=node.valency,
            )
    # [clamp_nonzero, c] with constant c: keep c if non-zero, else fail now
    elif node.value == "clamp_nonzero" and int_at(argz, 0):
        if get_int_at(argz, 0) != 0:
            return LLLnode(
                argz[0].value,
                [],
                node.typ,
                node.location,
                node.pos,
                node.annotation,
                add_gas_estimate=node.add_gas_estimate,
                valency=node.valency,
            )
        else:
            raise Exception("Clamp always fails")
    # [eq, x, 0] is the same as [iszero, x].
    elif node.value == "eq" and int_at(argz, 1) and argz[1].value == 0:
        return LLLnode(
            "iszero",
            [argz[0]],
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # [ne, x, y] has the same truthyness as [xor, x, y]
    # rewrite 'ne' as 'xor' in places where truthy is accepted.
    elif node.value in ("if", "if_unchecked", "assert") and argz[0].value == "ne":
        # NOTE(review): the replacement `xor` node does not carry over
        # argz[0]'s pos/annotation.
        argz[0] = LLLnode.from_list(["xor"] + argz[0].args)  # type: ignore
        return LLLnode.from_list(
            [node.value] + argz,  # type: ignore
            typ=node.typ,
            location=node.location,
            pos=node.pos,
            annotation=node.annotation,
            # let from_list handle valency and gas_estimate
        )
    # flatten nested seqs: [seq, a, [seq, b, c]] -> [seq, a, b, c]
    elif node.value == "seq":
        xs: List[Any] = []
        for arg in argz:
            if arg.value == "seq":
                xs.extend(arg.args)
            else:
                xs.append(arg)
        return LLLnode(
            node.value,
            xs,
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # no rewrite applies: rebuild with optimized children, replacing this
    # node's own gas contribution with the rebuilt node's
    elif node.total_gas is not None:
        o = LLLnode(
            node.value,
            argz,
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
        o.total_gas = node.total_gas - node.gas + o.gas
        o.func_name = node.func_name
        return o
    else:
        return LLLnode(
            node.value,
            argz,
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
def to_int128(expr, args, kwargs, context):
    """Convert ``args[0]`` to an ``int128`` node.

    Numeric literals, and literal bytes32/int256/uint256 inputs, are
    range-checked at compile time; Decimal literals are truncated toward zero
    first. Non-literal inputs are clamped at runtime:

    * bytes32/int256: clamped via ``int128_clamp``.
    * uint256: clamped against the value stored at ``MemoryPositions.MAX_INT128``.
    * decimal: ``sdiv`` by ``DECIMAL_DIVISOR``, then clamped via ``int128_clamp``.

    Other inputs:

    * address: masked to the address width, then sign-extended from byte 15.
    * String/Bytes (max length <= 32): read via ``byte_array_to_num``.
    * bool: passed through unchanged.

    Raises:
        InvalidLiteral: for out-of-range literals, unknown literal types, or
            unsupported input types.
        TypeMismatch: for String/Bytes inputs longer than 32 bytes.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "num_literal":
        if isinstance(in_arg, int):
            if not SizeLimits.in_bounds("int128", in_arg):
                raise InvalidLiteral(f"Number out of range: {in_arg}", expr)
            return LLLnode.from_list(in_arg, typ=BaseType("int128"), pos=getpos(expr))
        elif isinstance(in_arg, Decimal):
            truncated = math.trunc(in_arg)
            if not SizeLimits.in_bounds("int128", truncated):
                raise InvalidLiteral(f"Number out of range: {truncated}", expr)
            return LLLnode.from_list(truncated, typ=BaseType("int128"), pos=getpos(expr))
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}", expr)

    elif input_type in ("bytes32", "int256"):
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}", expr)
            else:
                return LLLnode.from_list(in_arg, typ=BaseType("int128"), pos=getpos(expr))
        else:
            return LLLnode.from_list(
                int128_clamp(in_arg),
                typ=BaseType("int128"),
                pos=getpos(expr),
            )

    elif input_type == "address":
        return LLLnode.from_list(
            ["signextend", 15, ["and", in_arg, (SizeLimits.ADDRSIZE - 1)]],
            typ=BaseType("int128"),
            pos=getpos(expr),
        )

    elif input_type in ("String", "Bytes"):
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to int128",
                expr,
            )
        return byte_array_to_num(in_arg, expr, "int128")

    elif input_type == "uint256":
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", in_arg.value):
                raise InvalidLiteral(f"Number out of range: {in_arg.value}", expr)
            else:
                return LLLnode.from_list(in_arg, typ=BaseType("int128"), pos=getpos(expr))
        else:
            return LLLnode.from_list(
                ["uclample", in_arg, ["mload", MemoryPositions.MAX_INT128]],
                typ=BaseType("int128"),
                pos=getpos(expr),
            )

    elif input_type == "decimal":
        return LLLnode.from_list(
            int128_clamp(["sdiv", in_arg, DECIMAL_DIVISOR]),
            typ=BaseType("int128"),
            pos=getpos(expr),
        )

    elif input_type == "bool":
        return LLLnode.from_list(in_arg, typ=BaseType("int128"), pos=getpos(expr))

    else:
        raise InvalidLiteral(f"Invalid input for int128: {in_arg}", expr)
def to_decimal(expr, args, kwargs, context):
    """Convert ``args[0]`` to a ``decimal`` node (input scaled by ``DECIMAL_DIVISOR``).

    Every branch multiplies the integer input by ``DECIMAL_DIVISOR``. Bounds
    are checked against ``MemoryPositions.MINDECIMAL`` / ``MAXDECIMAL``, which
    are compared with the scaled value.

    Fixes relative to the previous version:

    * The non-literal uint256 and bytes32 branches now clamp the input
      *before* scaling. Previously they clamped the product, but EVM ``mul``
      wraps modulo 2**256, so a huge input could wrap into range and pass
      the check.
    * The int256 branch now evaluates ``in_arg`` only once.

    Raises:
        TypeMismatch: for Bytes inputs longer than 32 bytes.
        InvalidLiteral: for out-of-range literals or unsupported input types.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "Bytes":
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to decimal",
                expr,
            )
        num = byte_array_to_num(in_arg, expr, "int128")
        return LLLnode.from_list(
            ["mul", num, DECIMAL_DIVISOR], typ=BaseType("decimal"), pos=getpos(expr)
        )

    elif input_type == "uint256":
        if in_arg.typ.is_literal:
            # NOTE(review): this compile-time check compares the scaled value
            # against int128, while the runtime path below compares against
            # MAXDECIMAL; confirm both bounds are meant to agree.
            if not SizeLimits.in_bounds("int128", (in_arg.value * DECIMAL_DIVISOR)):
                raise InvalidLiteral(
                    f"Number out of range: {in_arg.value}",
                    expr,
                )
            else:
                return LLLnode.from_list(
                    ["mul", in_arg, DECIMAL_DIVISOR], typ=BaseType("decimal"), pos=getpos(expr)
                )
        else:
            # x <= MAXDECIMAL // D  <=>  x * D <= MAXDECIMAL, and this form
            # cannot overflow, so it is safe to check before the `mul`.
            max_input = ["div", ["mload", MemoryPositions.MAXDECIMAL], DECIMAL_DIVISOR]
            return LLLnode.from_list(
                ["mul", ["uclample", in_arg, max_input], DECIMAL_DIVISOR],
                typ=BaseType("decimal"),
                pos=getpos(expr),
            )

    elif input_type == "address":
        return LLLnode.from_list(
            [
                "mul",
                ["signextend", 15, ["and", in_arg, (SizeLimits.ADDRSIZE - 1)]],
                DECIMAL_DIVISOR,
            ],
            typ=BaseType("decimal"),
            pos=getpos(expr),
        )

    elif input_type == "bytes32":
        if in_arg.typ.is_literal:
            if not SizeLimits.in_bounds("int128", (in_arg.value * DECIMAL_DIVISOR)):
                raise InvalidLiteral(
                    f"Number out of range: {in_arg.value}",
                    expr,
                )
            else:
                return LLLnode.from_list(
                    ["mul", in_arg, DECIMAL_DIVISOR], typ=BaseType("decimal"), pos=getpos(expr)
                )
        else:
            # Signed analogue of the uint256 case. sdiv truncates toward zero,
            # so MIN sdiv D <= x <= MAX sdiv D  <=>  MIN <= x * D <= MAX,
            # and the checked x cannot make the product overflow.
            min_input = ["sdiv", ["mload", MemoryPositions.MINDECIMAL], DECIMAL_DIVISOR]
            max_input = ["sdiv", ["mload", MemoryPositions.MAXDECIMAL], DECIMAL_DIVISOR]
            return LLLnode.from_list(
                ["mul", ["clamp", min_input, in_arg, max_input], DECIMAL_DIVISOR],
                typ=BaseType("decimal"),
                pos=getpos(expr),
            )

    elif input_type == "int256":
        # int128_clamp yields the (checked) value itself, so in_arg is
        # evaluated exactly once.
        return LLLnode.from_list(
            ["mul", int128_clamp(in_arg), DECIMAL_DIVISOR],
            typ=BaseType("decimal"),
            pos=getpos(expr),
        )

    elif input_type in ("int128", "bool"):
        return LLLnode.from_list(
            ["mul", in_arg, DECIMAL_DIVISOR], typ=BaseType("decimal"), pos=getpos(expr)
        )

    else:
        raise InvalidLiteral(f"Invalid input for decimal: {in_arg}", expr)
def to_int256(expr, args, kwargs, context):
    """Convert ``args[0]`` to an ``int256`` node.

    Numeric literals are range-checked at compile time; Decimal literals are
    truncated toward zero first. Other conversions:

    * int128, bool: passed through unchanged.
    * uint256: checked with ``uclamplt`` against 2**255 (built with ``shl``
      on constantinople+, else as the constant ``-(2**255)``, which is the
      same 256-bit word).
    * decimal: ``sdiv`` by ``DECIMAL_DIVISOR``.
    * bytes32, address: reinterpreted as int256 without changing the value.
    * Bytes/String (max length <= 32): read via ``byte_array_to_num``.

    Raises:
        InvalidLiteral: for out-of-range literals, unknown literal types, or
            unsupported input types.
        TypeMismatch: for Bytes/String inputs longer than 32 bytes.
    """
    in_arg = args[0]
    input_type, _ = get_type(in_arg)

    if input_type == "num_literal":
        if isinstance(in_arg, int):
            if not SizeLimits.in_bounds("int256", in_arg):
                raise InvalidLiteral(f"Number out of range: {in_arg}", expr)
            return LLLnode.from_list(in_arg, typ=BaseType("int256"), pos=getpos(expr))
        elif isinstance(in_arg, Decimal):
            truncated = math.trunc(in_arg)
            if not SizeLimits.in_bounds("int256", truncated):
                raise InvalidLiteral(f"Number out of range: {truncated}", expr)
            return LLLnode.from_list(truncated, typ=BaseType("int256"), pos=getpos(expr))
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {in_arg}", expr)

    elif isinstance(in_arg, LLLnode) and input_type == "int128":
        return LLLnode.from_list(in_arg, typ=BaseType("int256"), pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type == "uint256":
        if version_check(begin="constantinople"):
            upper_bound = ["shl", 255, 1]
        else:
            upper_bound = -(2**255)
        return LLLnode.from_list(
            ["uclamplt", in_arg, upper_bound], typ=BaseType("int256"), pos=getpos(expr)
        )

    elif isinstance(in_arg, LLLnode) and input_type == "decimal":
        return LLLnode.from_list(
            ["sdiv", in_arg, DECIMAL_DIVISOR],
            typ=BaseType("int256"),
            pos=getpos(expr),
        )

    elif isinstance(in_arg, LLLnode) and input_type == "bool":
        return LLLnode.from_list(in_arg, typ=BaseType("int256"), pos=getpos(expr))

    elif isinstance(in_arg, LLLnode) and input_type in ("bytes32", "address"):
        return LLLnode(
            value=in_arg.value, args=in_arg.args, typ=BaseType("int256"), pos=getpos(expr)
        )

    elif isinstance(in_arg, LLLnode) and input_type in ("Bytes", "String"):
        if in_arg.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {in_arg.typ.maxlen} to int256",
                expr,
            )
        return byte_array_to_num(in_arg, expr, "int256")

    else:
        raise InvalidLiteral(f"Invalid input for int256: {in_arg}", expr)