Ejemplo n.º 1
0
def test_sha3_32():
    """Check that ``sha3_32`` compiles to the same assembly with and without optimization."""
    source = ["sha3_32", 0]
    expected_asm = [
        "PUSH1", 0,
        "PUSH1", 192,
        "MSTORE",
        "PUSH1", 32,
        "PUSH1", 192,
        "SHA3",
    ]

    # Straight compilation of the unoptimized node.
    plain_node = LLLnode.from_list(source)
    assert compile_lll.compile_to_assembly(plain_node) == expected_asm

    # The optimizer must not change the emitted assembly.
    optimized_node = optimizer.optimize(LLLnode.from_list(source))
    assert compile_lll.compile_to_assembly(optimized_node) == expected_asm
Ejemplo n.º 2
0
def to_address(expr, args, kwargs, context):
    """Build the LLL node that converts the first argument to ``address``.

    The argument is bound to ``_in_arg`` in a ``with`` expression; the body
    is a ``seq`` of ``address_clamp("_in_arg")`` followed by ``_in_arg``.
    ``kwargs`` and ``context`` are accepted but not used.
    """
    value = args[0]
    clamped_body = ["seq", address_clamp("_in_arg"), "_in_arg"]
    return LLLnode.from_list(
        ["with", "_in_arg", value, clamped_body],
        typ=BaseType("address"),
        pos=getpos(expr),
    )
Ejemplo n.º 3
0
def to_bool(expr, args, kwargs, context):
    """Build the LLL node that converts the first argument to ``bool``.

    A ``Bytes`` input is first turned into a ``uint256`` number via
    ``byte_array_to_num``; every input is then normalised with a double
    ``iszero``.

    Raises:
        TypeMismatch: if a ``Bytes`` input has a max length above 32.
    """
    value = args[0]
    input_type, _ = get_type(value)

    if input_type == "Bytes":
        max_length = value.typ.maxlen
        if max_length > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {max_length} to bool",
                expr,
            )
        value = byte_array_to_num(value, expr, "uint256")

    return LLLnode.from_list(
        ["iszero", ["iszero", value]],
        typ=BaseType("bool"),
        pos=getpos(expr),
    )
Ejemplo n.º 4
0
 def lll_compiler(lll, *args, **kwargs):
     """Optimize and compile ``lll``, deploy it through ``w3`` and return the contract.

     Only the ``abi`` keyword argument is read (an empty list is used when it
     is missing or falsy); positional extras are ignored.
     """
     optimized = optimizer.optimize(LLLnode.from_list(lll))
     assembly = compile_lll.compile_to_assembly(optimized)
     bytecode, _ = compile_lll.assembly_to_evm(assembly)
     abi = kwargs.get("abi") or []

     # Deploy the bytecode and look up where it landed.
     deployer = w3.eth.contract(abi=abi, bytecode=bytecode)
     tx_hash = deployer.constructor().transact()
     receipt = w3.eth.getTransactionReceipt(tx_hash)

     return w3.eth.contract(
         receipt["contractAddress"],
         abi=abi,
         bytecode=bytecode,
         ContractFactoryClass=VyperContract,
     )
Ejemplo n.º 5
0
def to_bytes32(expr, args, kwargs, context):
    """Build the LLL node that converts the first argument to ``bytes32``.

    Non-``Bytes`` inputs are re-typed as ``bytes32`` without changing their
    value or arguments.  A ``Bytes`` input is loaded from one word past its
    address: ``mload`` at ``+32`` for memory, ``sload`` at ``+1`` for storage
    (presumably skipping a leading length word -- confirm against the byte
    array layout).

    ``kwargs`` and ``context`` are accepted but not used.

    Raises:
        TypeMismatch: if a ``Bytes`` input is longer than 32 bytes, or if it
            lives in a location other than memory or storage (previously this
            case silently returned ``None``).
    """
    in_arg = args[0]
    input_type, _len = get_type(in_arg)

    if input_type != "Bytes":
        return LLLnode(value=in_arg.value,
                       args=in_arg.args,
                       typ=BaseType("bytes32"),
                       pos=getpos(expr))

    if _len > 32:
        raise TypeMismatch(
            f"Unable to convert bytes[{_len}] to bytes32, max length is too "
            "large.",
            expr,
        )

    if in_arg.location == "memory":
        word = ["mload", ["add", in_arg, 32]]
    elif in_arg.location == "storage":
        word = ["sload", ["add", in_arg, 1]]
    else:
        # Fail loudly instead of falling off the end and returning None.
        raise TypeMismatch(
            f"Unable to convert bytes in location {in_arg.location!r} to bytes32",
            expr,
        )

    # Carry the source position like every other conversion helper does.
    return LLLnode.from_list(word,
                             typ=BaseType("bytes32"),
                             pos=getpos(expr))
Ejemplo n.º 6
0
def _merge_memzero(argz):
    """Merge runs of contiguous memory-zeroing operations in ``argz`` in place.

    Two node shapes count as zeroing a literal memory range:

    * ``[mstore, <int offset>, 0]`` -- zeroes 32 bytes at ``offset``
    * ``[calldatacopy, <int offset>, calldatasize, <int length>]`` -- copies
      from the end of calldata, i.e. zeroes ``length`` bytes at ``offset``

    A run of two or more such nodes, each starting where the previous one
    ended and separated only by ``pass`` nodes, is replaced by a single
    ``[calldatacopy, start, calldatasize, total_length]`` node placed where
    the first node of the run was.  ``pass`` nodes are left untouched.

    Compared with the previous implementation this also merges a run that
    reaches the end of ``argz``, lets a non-contiguous zeroing node start a
    new run, and tracks run members by position rather than by node equality
    so that an equal-looking node elsewhere in the list is never touched.

    Args:
        argz: list of LLL nodes (the arguments of a ``seq``); mutated in place.
    """
    out: List = []        # the rewritten argument list being built
    run_slots: List = []  # indices into `out` of the nodes in the current run
    initial_offset = 0
    total_length = 0

    def _zeroed_span(lll_node):
        # Return (offset, length) if the node zeroes a literal range, else None.
        if (lll_node.value == "mstore"
                and isinstance(lll_node.args[0].value, int)
                and lll_node.args[1].value == 0):
            return lll_node.args[0].value, 32
        if (lll_node.value == "calldatacopy"
                and isinstance(lll_node.args[0].value, int)
                and lll_node.args[1].value == "calldatasize"
                and isinstance(lll_node.args[2].value, int)):
            return lll_node.args[0].value, lll_node.args[2].value
        return None

    def _flush():
        # Close the current run, collapsing it when it has 2+ members.
        if len(run_slots) > 1:
            first = run_slots[0]
            out[first] = LLLnode.from_list(
                ["calldatacopy", initial_offset, "calldatasize", total_length],
                pos=out[first].pos,
            )
            # delete back-to-front so the remaining indices stay valid
            for slot in reversed(run_slots[1:]):
                del out[slot]
        run_slots.clear()

    for lll_node in argz:
        if lll_node.value == "pass":
            # `pass` nodes neither extend nor interrupt a run
            out.append(lll_node)
            continue

        span = _zeroed_span(lll_node)
        if span is None or (run_slots
                            and initial_offset + total_length != span[0]):
            _flush()

        if span is not None:
            offset, length = span
            if not run_slots:
                initial_offset = offset
                total_length = 0
            run_slots.append(len(out))
            total_length += length

        out.append(lll_node)

    # a run that extends to the end of the list still has to be merged
    _flush()
    argz[:] = out
Ejemplo n.º 7
0
def to_uint256(expr, args, kwargs, context):
    """Build the LLL node that converts the first argument to ``uint256``.

    Handled inputs:

    * numeric literals (``int`` or ``Decimal``, the latter truncated) --
      bounds-checked against ``uint256``
    * ``int128`` / ``int256`` nodes -- wrapped in ``clampge ... 0``
    * ``decimal`` nodes -- ``clampge ... 0`` then divided by ``DECIMAL_DIVISOR``
    * ``bool`` nodes -- re-typed
    * ``bytes32`` / ``address`` nodes -- re-typed, keeping value and args
    * ``Bytes`` nodes of max length <= 32 -- via ``byte_array_to_num``

    Raises:
        InvalidLiteral: for out-of-range or unknown literals, over-long byte
            arrays, and any other unsupported input.
    """
    value = args[0]
    input_type, _ = get_type(value)

    if input_type == "num_literal":
        if isinstance(value, int):
            literal = value
        elif isinstance(value, Decimal):
            literal = math.trunc(value)
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {value}")

        if not SizeLimits.in_bounds("uint256", literal):
            raise InvalidLiteral(f"Number out of range: {literal}")
        return LLLnode.from_list(literal,
                                 typ=BaseType("uint256"),
                                 pos=getpos(expr))

    # Everything below only applies to already-built LLL nodes.
    if not isinstance(value, LLLnode):
        raise InvalidLiteral(f"Invalid input for uint256: {value}", expr)

    if input_type in ("bytes32", "address"):
        return LLLnode(value=value.value,
                       args=value.args,
                       typ=BaseType("uint256"),
                       pos=getpos(expr))

    if input_type == "Bytes":
        if value.typ.maxlen > 32:
            raise InvalidLiteral(
                f"Cannot convert bytes array of max length {value.typ.maxlen} to uint256",
                expr,
            )
        return byte_array_to_num(value, expr, "uint256")

    if input_type in ("int128", "int256"):
        lll = ["clampge", value, 0]
    elif input_type == "decimal":
        lll = ["div", ["clampge", value, 0], DECIMAL_DIVISOR]
    elif input_type == "bool":
        lll = value
    else:
        raise InvalidLiteral(f"Invalid input for uint256: {value}", expr)

    return LLLnode.from_list(lll,
                             typ=BaseType("uint256"),
                             pos=getpos(expr))
Ejemplo n.º 8
0
def _merge_calldataload(argz):
    """Merge runs of word-by-word calldata-to-memory copies in ``argz`` in place.

    A copy node has the shape
    ``[mstore, <int mem_offset>, [calldataload, <int calldata_offset>]]`` and
    moves 32 bytes.  A run of two or more such nodes whose memory offsets and
    calldata offsets both advance contiguously (separated only by ``pass``
    nodes) is replaced by a single
    ``[calldatacopy, mem_start, calldata_start, total_length]`` node placed
    where the first node of the run was.  ``pass`` nodes are left untouched.

    Compared with the previous implementation this also merges a run that
    reaches the end of ``argz``, lets a non-contiguous copy node start a new
    run, and tracks run members by position rather than by node equality so
    that an equal-looking node elsewhere in the list is never touched.

    Args:
        argz: list of LLL nodes (the arguments of a ``seq``); mutated in place.
    """
    out: List = []        # the rewritten argument list being built
    run_slots: List = []  # indices into `out` of the nodes in the current run
    initial_mem_offset = 0
    initial_calldata_offset = 0
    total_length = 0

    def _copy_offsets(lll_node):
        # Return (mem_offset, calldata_offset) for a literal one-word copy,
        # else None.
        if (lll_node.value == "mstore"
                and isinstance(lll_node.args[0].value, int)
                and lll_node.args[1].value == "calldataload"
                and isinstance(lll_node.args[1].args[0].value, int)):
            return lll_node.args[0].value, lll_node.args[1].args[0].value
        return None

    def _flush():
        # Close the current run, collapsing it when it has 2+ members.
        if len(run_slots) > 1:
            first = run_slots[0]
            out[first] = LLLnode.from_list(
                [
                    "calldatacopy", initial_mem_offset,
                    initial_calldata_offset, total_length
                ],
                pos=out[first].pos,
            )
            # delete back-to-front so the remaining indices stay valid
            for slot in reversed(run_slots[1:]):
                del out[slot]
        run_slots.clear()

    for lll_node in argz:
        if lll_node.value == "pass":
            # `pass` nodes neither extend nor interrupt a run
            out.append(lll_node)
            continue

        offsets = _copy_offsets(lll_node)
        if run_slots and offsets is not None:
            expected = (initial_mem_offset + total_length,
                        initial_calldata_offset + total_length)
            contiguous = offsets == expected
        else:
            contiguous = False
        if not contiguous:
            _flush()

        if offsets is not None:
            if not run_slots:
                initial_mem_offset, initial_calldata_offset = offsets
                total_length = 0
            run_slots.append(len(out))
            total_length += 32

        out.append(lll_node)

    # a run that extends to the end of the list still has to be merged
    _flush()
    argz[:] = out
Ejemplo n.º 9
0
def _to_bytelike(expr, args, kwargs, context, bytetype):
    """Re-type a byte-like node as ``String`` or ``Bytes``.

    The result keeps the input's value, args, location and max length; only
    the type class changes (``StringType`` or ``ByteArrayType``).

    Raises:
        TypeMismatch: if ``bytetype`` is neither ``"String"`` nor ``"Bytes"``,
            or if the input's max length exceeds the length read from
            ``args[1].slice.value.n``.
    """
    if bytetype not in ("String", "Bytes"):
        raise TypeMismatch(f"Invalid {bytetype} supplied")
    ReturnType = StringType if bytetype == "String" else ByteArrayType

    source = args[0]
    source_maxlen = source.typ.maxlen
    target_maxlen = args[1].slice.value.n
    if source_maxlen > target_maxlen:
        raise TypeMismatch(
            f"Cannot convert as input {bytetype} are larger than max length",
            expr,
        )

    return LLLnode(
        value=source.value,
        args=source.args,
        typ=ReturnType(source.typ.maxlen),
        pos=getpos(expr),
        location=source.location,
    )
Ejemplo n.º 10
0
def compile_to_lll(input_file, output_formats, show_gas_estimates=False):
    """Compile the first s-expression in ``input_file`` to the requested formats.

    Args:
        input_file: path of a file containing LLL s-expressions.
        output_formats: container of format names; ``"ir"``, ``"opt_ir"``,
            ``"asm"`` and ``"bytecode"`` are recognised.
        show_gas_estimates: when true, sets ``LLLnode.repr_show_gas``.

    Returns:
        dict mapping each requested format name to its output.  Assembly (and
        therefore bytecode) is generated from the unoptimized node, and is
        always computed even when not requested.
    """
    with open(input_file) as source:
        parsed = parse_s_exp(source.read())

    if show_gas_estimates:
        LLLnode.repr_show_gas = True

    root = LLLnode.from_list(parsed[0])
    outputs = {}

    if "ir" in output_formats:
        outputs["ir"] = root
    if "opt_ir" in output_formats:
        outputs["opt_ir"] = optimizer.optimize(root)

    assembly = compile_lll.compile_to_assembly(root)
    if "asm" in output_formats:
        outputs["asm"] = assembly
    if "bytecode" in output_formats:
        bytecode, _ = compile_lll.assembly_to_evm(assembly)
        outputs["bytecode"] = "0x" + bytecode.hex()

    return outputs
Ejemplo n.º 11
0
def apply_general_optimizations(node: LLLnode) -> LLLnode:
    """Return an optimized copy of ``node``.

    The arguments are optimized recursively first; then the first matching
    rewrite in the ``if``/``elif`` chain below is applied to the node itself.
    The order of the branches matters: only one rewrite is applied per node.

    Raises:
        Exception: when a ``clamp`` / ``clamp_nonzero`` on literal operands
            can be shown at compile time to always fail.
    """
    # TODO refactor this into several functions
    # optimize children bottom-up before looking at this node
    argz = [apply_general_optimizations(arg) for arg in node.args]

    if node.value == "seq":
        # both helpers are called for their effect on `argz`; their return
        # values are not used
        _merge_memzero(argz)
        _merge_calldataload(argz)

    # constant folding: arithmetic on two integer operands
    if node.value in arith and int_at(argz, 0) and int_at(argz, 1):
        left, right = get_int_at(argz, 0), get_int_at(argz, 1)
        # `node.value in arith` implies that `node.value` is a `str`
        calcer, symb = arith[str(node.value)]
        new_value = calcer(left, right)
        # build an annotation for the folded value out of the operands'
        # annotations, falling back to the literal when only one has one
        if argz[0].annotation and argz[1].annotation:
            annotation = argz[0].annotation + symb + argz[1].annotation
        elif argz[0].annotation or argz[1].annotation:
            annotation = ((argz[0].annotation or str(left)) + symb +
                          (argz[1].annotation or str(right)))
        else:
            annotation = ""
        return LLLnode(
            new_value,
            [],
            node.typ,
            None,
            node.pos,
            annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # `_is_constant_add` is defined elsewhere; judging by the rewrite below
    # it presumably matches [add, c1, [add, c2, x]] -- TODO confirm
    elif _is_constant_add(node, argz):
        # `node.value in arith` implies that `node.value` is a `str`
        # NOTE(review): `calcer` is unpacked but never used in this branch
        calcer, symb = arith[str(node.value)]
        if argz[0].annotation and argz[1].args[0].annotation:
            annotation = argz[0].annotation + symb + argz[1].args[0].annotation
        elif argz[0].annotation or argz[1].args[0].annotation:
            annotation = (
                (argz[0].annotation or str(argz[0].value)) + symb +
                (argz[1].args[0].annotation or str(argz[1].args[0].value)))
        else:
            annotation = ""
        # result: [add, (argz[0] + argz[1].args[0]), argz[1].args[1]]
        # NOTE(review): unlike the other branches, no `pos` is passed to
        # this LLLnode -- confirm whether that is intentional
        return LLLnode(
            "add",
            [
                LLLnode(int(argz[0].value) + int(argz[1].args[0].value),
                        annotation=annotation),
                argz[1].args[1],
            ],
            node.typ,
            None,
            annotation=node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # [add, 0, x] -> x (keeping this node's typ, location and pos)
    elif node.value == "add" and get_int_at(argz, 0) == 0:
        return LLLnode(
            argz[1].value,
            argz[1].args,
            node.typ,
            node.location,
            node.pos,
            annotation=argz[1].annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # [add, x, 0] -> x (keeping this node's typ, location and pos)
    elif node.value == "add" and get_int_at(argz, 1) == 0:
        return LLLnode(
            argz[0].value,
            argz[0].args,
            node.typ,
            node.location,
            node.pos,
            argz[0].annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # clamp with three integer operands: decide it here, returning the
    # middle operand when the operands are in non-decreasing order
    elif node.value == "clamp" and int_at(argz, 0) and int_at(
            argz, 1) and int_at(argz, 2):
        if get_int_at(argz, 0, True) > get_int_at(argz, 1,
                                                  True):  # type: ignore
            raise Exception("Clamp always fails")
        elif get_int_at(argz, 1, True) > get_int_at(argz, 2,
                                                    True):  # type: ignore
            raise Exception("Clamp always fails")
        else:
            return argz[1]
    # clamp whose first two operands are integers: the lower-bound check is
    # decided here, so only [clample, argz[1], argz[2]] remains
    elif node.value == "clamp" and int_at(argz, 0) and int_at(argz, 1):
        if get_int_at(argz, 0, True) > get_int_at(argz, 1,
                                                  True):  # type: ignore
            raise Exception("Clamp always fails")
        else:
            return LLLnode(
                "clample",
                [argz[1], argz[2]],
                node.typ,
                node.location,
                node.pos,
                node.annotation,
                add_gas_estimate=node.add_gas_estimate,
                valency=node.valency,
            )
    # clamp_nonzero on an integer operand: replace with the operand's value
    # when it is non-zero, otherwise fail at compile time
    elif node.value == "clamp_nonzero" and int_at(argz, 0):
        if get_int_at(argz, 0) != 0:
            return LLLnode(
                argz[0].value,
                [],
                node.typ,
                node.location,
                node.pos,
                node.annotation,
                add_gas_estimate=node.add_gas_estimate,
                valency=node.valency,
            )
        else:
            raise Exception("Clamp always fails")
    # [eq, x, 0] is the same as [iszero, x].
    elif node.value == "eq" and int_at(argz, 1) and argz[1].value == 0:
        return LLLnode(
            "iszero",
            [argz[0]],
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # [ne, x, y] has the same truthyness as [xor, x, y]
    # rewrite 'ne' as 'xor' in places where truthy is accepted.
    elif node.value in ("if", "if_unchecked",
                        "assert") and argz[0].value == "ne":
        argz[0] = LLLnode.from_list(["xor"] + argz[0].args)  # type: ignore
        return LLLnode.from_list(
            [node.value] + argz,  # type: ignore
            typ=node.typ,
            location=node.location,
            pos=node.pos,
            annotation=node.annotation,
            # let from_list handle valency and gas_estimate
        )
    # flatten one level of nested seq: [seq, a, [seq, b, c]] -> [seq, a, b, c]
    elif node.value == "seq":
        xs: List[Any] = []
        for arg in argz:
            if arg.value == "seq":
                xs.extend(arg.args)
            else:
                xs.append(arg)
        return LLLnode(
            node.value,
            xs,
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
    # no rewrite applies, but the node carries a total gas figure: rebuild
    # it with the optimized args and carry the figure (and func_name) over
    elif node.total_gas is not None:
        o = LLLnode(
            node.value,
            argz,
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
        # swap the old node's own gas for the rebuilt node's gas
        o.total_gas = node.total_gas - node.gas + o.gas
        o.func_name = node.func_name
        return o
    # default: rebuild the node unchanged around the optimized args
    else:
        return LLLnode(
            node.value,
            argz,
            node.typ,
            node.location,
            node.pos,
            node.annotation,
            add_gas_estimate=node.add_gas_estimate,
            valency=node.valency,
        )
Ejemplo n.º 12
0
def to_int128(expr, args, kwargs, context):
    """Build the LLL node that converts the first argument to ``int128``.

    Handled inputs:

    * numeric literals (``int`` or ``Decimal``, the latter truncated) --
      bounds-checked against ``int128``
    * ``bytes32`` / ``int256`` / ``uint256`` -- literals are bounds-checked
      and re-typed; other nodes go through ``int128_clamp`` (``uclample``
      against ``MemoryPositions.MAX_INT128`` for ``uint256``)
    * ``address`` -- masked with ``SizeLimits.ADDRSIZE - 1`` and sign-extended
    * ``String`` / ``Bytes`` of max length <= 32 -- via ``byte_array_to_num``
    * ``decimal`` -- ``sdiv`` by ``DECIMAL_DIVISOR`` then ``int128_clamp``
    * ``bool`` -- re-typed

    Raises:
        InvalidLiteral: for out-of-range or unknown literals and for any
            unsupported input.
        TypeMismatch: for byte arrays with a max length above 32.
    """
    value = args[0]
    input_type, _ = get_type(value)

    if input_type == "num_literal":
        if isinstance(value, int):
            literal = value
        elif isinstance(value, Decimal):
            literal = math.trunc(value)
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {value}")

        if not SizeLimits.in_bounds("int128", literal):
            raise InvalidLiteral(f"Number out of range: {literal}")
        return LLLnode.from_list(literal,
                                 typ=BaseType("int128"),
                                 pos=getpos(expr))

    if input_type in ("String", "Bytes"):
        if value.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {value.typ.maxlen} to int128",
                expr,
            )
        return byte_array_to_num(value, expr, "int128")

    if input_type in ("bytes32", "int256", "uint256"):
        if value.typ.is_literal:
            # literal nodes are range-checked at compile time
            if not SizeLimits.in_bounds("int128", value.value):
                raise InvalidLiteral(f"Number out of range: {value.value}",
                                     expr)
            lll = value
        elif input_type == "uint256":
            lll = ["uclample", value, ["mload", MemoryPositions.MAX_INT128]]
        else:
            lll = int128_clamp(value)
    elif input_type == "address":
        lll = ["signextend", 15, ["and", value, (SizeLimits.ADDRSIZE - 1)]]
    elif input_type == "decimal":
        lll = int128_clamp(["sdiv", value, DECIMAL_DIVISOR])
    elif input_type == "bool":
        lll = value
    else:
        raise InvalidLiteral(f"Invalid input for int128: {value}", expr)

    return LLLnode.from_list(lll,
                             typ=BaseType("int128"),
                             pos=getpos(expr))
Ejemplo n.º 13
0
def to_decimal(expr, args, kwargs, context):
    """Build the LLL node that converts the first argument to ``decimal``.

    Every supported input ends up multiplied by ``DECIMAL_DIVISOR``:

    * ``Bytes`` of max length <= 32 -- first converted with
      ``byte_array_to_num(..., "int128")``
    * ``uint256`` / ``bytes32`` -- literals are bounds-checked (after scaling)
      against ``int128``; other nodes are clamped at runtime (``uclample``
      against ``MAXDECIMAL`` for ``uint256``, ``clamp`` between ``MINDECIMAL``
      and ``MAXDECIMAL`` for ``bytes32``)
    * ``address`` -- masked with ``SizeLimits.ADDRSIZE - 1`` and sign-extended
    * ``int256`` -- ``int128_clamp`` on the input, then the scaled value
    * ``int128`` / ``bool`` -- scaled directly

    Raises:
        TypeMismatch: for byte arrays with a max length above 32.
        InvalidLiteral: for out-of-range literals and unsupported inputs.
    """
    value = args[0]
    input_type, _ = get_type(value)

    if input_type == "Bytes":
        if value.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {value.typ.maxlen} to decimal",
                expr,
            )
        number = byte_array_to_num(value, expr, "int128")
        lll = ["mul", number, DECIMAL_DIVISOR]

    elif input_type in ("uint256", "bytes32"):
        if value.typ.is_literal:
            # literal nodes are range-checked at compile time
            if not SizeLimits.in_bounds("int128",
                                        (value.value * DECIMAL_DIVISOR)):
                raise InvalidLiteral(
                    f"Number out of range: {value.value}",
                    expr,
                )
            lll = ["mul", value, DECIMAL_DIVISOR]
        elif input_type == "uint256":
            lll = [
                "uclample",
                ["mul", value, DECIMAL_DIVISOR],
                ["mload", MemoryPositions.MAXDECIMAL],
            ]
        else:
            lll = [
                "clamp",
                ["mload", MemoryPositions.MINDECIMAL],
                ["mul", value, DECIMAL_DIVISOR],
                ["mload", MemoryPositions.MAXDECIMAL],
            ]

    elif input_type == "address":
        as_int = ["signextend", 15, ["and", value, (SizeLimits.ADDRSIZE - 1)]]
        lll = ["mul", as_int, DECIMAL_DIVISOR]

    elif input_type == "int256":
        lll = ["seq", int128_clamp(value), ["mul", value, DECIMAL_DIVISOR]]

    elif input_type in ("int128", "bool"):
        lll = ["mul", value, DECIMAL_DIVISOR]

    else:
        raise InvalidLiteral(f"Invalid input for decimal: {value}", expr)

    return LLLnode.from_list(lll,
                             typ=BaseType("decimal"),
                             pos=getpos(expr))
Ejemplo n.º 14
0
def to_int256(expr, args, kwargs, context):
    """Build the LLL node that converts the first argument to ``int256``.

    Handled inputs:

    * numeric literals (``int`` or ``Decimal``, the latter truncated) --
      bounds-checked against ``int256``
    * ``int128`` / ``bool`` nodes -- re-typed
    * ``uint256`` nodes -- ``uclamplt`` against ``[shl, 255, 1]`` when
      ``version_check(begin="constantinople")`` holds, else ``-(2**255)``
    * ``decimal`` nodes -- ``sdiv`` by ``DECIMAL_DIVISOR``
    * ``bytes32`` / ``address`` nodes -- re-typed, keeping value and args
    * ``Bytes`` / ``String`` nodes of max length <= 32 -- via
      ``byte_array_to_num``

    Raises:
        InvalidLiteral: for out-of-range or unknown literals and for any
            unsupported input.
        TypeMismatch: for byte arrays with a max length above 32.
    """
    value = args[0]
    input_type, _ = get_type(value)

    if input_type == "num_literal":
        if isinstance(value, int):
            literal = value
        elif isinstance(value, Decimal):
            literal = math.trunc(value)
        else:
            raise InvalidLiteral(f"Unknown numeric literal type: {value}")

        if not SizeLimits.in_bounds("int256", literal):
            raise InvalidLiteral(f"Number out of range: {literal}")
        return LLLnode.from_list(literal,
                                 typ=BaseType("int256"),
                                 pos=getpos(expr))

    # Everything below only applies to already-built LLL nodes.
    if not isinstance(value, LLLnode):
        raise InvalidLiteral(f"Invalid input for int256: {value}", expr)

    if input_type in ("bytes32", "address"):
        return LLLnode(value=value.value,
                       args=value.args,
                       typ=BaseType("int256"),
                       pos=getpos(expr))

    if input_type in ("Bytes", "String"):
        if value.typ.maxlen > 32:
            raise TypeMismatch(
                f"Cannot convert bytes array of max length {value.typ.maxlen} to int256",
                expr,
            )
        return byte_array_to_num(value, expr, "int256")

    if input_type in ("int128", "bool"):
        lll = value
    elif input_type == "uint256":
        if version_check(begin="constantinople"):
            upper_bound = ["shl", 255, 1]
        else:
            upper_bound = -(2**255)
        lll = ["uclamplt", value, upper_bound]
    elif input_type == "decimal":
        lll = ["sdiv", value, DECIMAL_DIVISOR]
    else:
        raise InvalidLiteral(f"Invalid input for int256: {value}", expr)

    return LLLnode.from_list(lll,
                             typ=BaseType("int256"),
                             pos=getpos(expr))