Exemplo n.º 1
0
 def visitVariableDeclaration(self, ast: VariableDeclaration):
     """Register an InstanceTarget for this declaration in ``ast.modified_values`` (value None)."""
     # NOTE(review): modified_values appears to be used as an ordered set (values are always None) — confirm.
     declared_target = InstanceTarget(ast)
     ast.modified_values[declared_target] = None
Exemplo n.º 2
0
# BUILTIN SPECIAL TYPE DEFINITIONS
from zkay.zkay_ast.ast import AnnotatedTypeName, FunctionTypeName, Parameter, Identifier, StructDefinition, \
    VariableDeclaration, TypeName, StateVariableDeclaration, UserDefinedTypeName, StructTypeName, Block, ConstructorOrFunctionDefinition
from zkay.zkay_ast.pointers.parent_setter import set_parents

# Shared declaration named 'length' with type AnnotatedTypeName.uint_all();
# presumably the target for `.length` accesses on arrays — confirm at use sites.
array_length_member = VariableDeclaration([], AnnotatedTypeName.uint_all(), Identifier('length'))


class GlobalDefs:
    # gasleft: FunctionDefinition = FunctionDefinition(
    #     idf=Identifier('gasleft'),
    #     parameters=[],
    #     modifiers=[],
    #     return_parameters=[Parameter([], annotated_type=AnnotatedTypeName.uint_all(), idf=Identifier(''))],
    #     body=Block([])
    # )
    # gasleft.idf.parent = gasleft

    address_struct: StructDefinition = StructDefinition(
        Identifier('<address>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance'))
        ])
    set_parents(address_struct)

    address_payable_struct: StructDefinition = StructDefinition(
        Identifier('<address_payable>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance')),
            ConstructorOrFunctionDefinition(
Exemplo n.º 3
0
class GlobalDefs:
    """
    Synthetic struct definitions for the built-in members of ``address``,
    ``address payable``, ``msg``, ``block`` and ``tx``.

    Every struct is constructed once, at class-definition time, and is then
    passed to ``set_parents``. The angle-bracketed names (e.g. ``<msg>``) are
    presumably chosen so they cannot clash with user identifiers — confirm.
    """

    # Disabled definition of the gasleft() builtin, kept for reference only.
    # gasleft: FunctionDefinition = FunctionDefinition(
    #     idf=Identifier('gasleft'),
    #     parameters=[],
    #     modifiers=[],
    #     return_parameters=[Parameter([], annotated_type=AnnotatedTypeName.uint_all(), idf=Identifier(''))],
    #     body=Block([])
    # )
    # gasleft.idf.parent = gasleft

    # address: single uint member `balance`.
    address_struct: StructDefinition = StructDefinition(
        Identifier('<address>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance'))
        ])
    set_parents(address_struct)

    # address payable: `balance` plus public `send(uint) -> bool` and `transfer(uint)`.
    address_payable_struct: StructDefinition = StructDefinition(
        Identifier('<address_payable>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance')),
            ConstructorOrFunctionDefinition(
                Identifier('send'),
                [Parameter([], AnnotatedTypeName.uint_all(), Identifier(''))],
                ['public'],
                [Parameter([], AnnotatedTypeName.bool_all(), Identifier(''))],
                Block([])),
            ConstructorOrFunctionDefinition(
                Identifier('transfer'),
                [Parameter([], AnnotatedTypeName.uint_all(), Identifier(''))],
                ['public'], [], Block([])),
        ])
    # members[1] is `send` and members[2] is `transfer` (members[0] is `balance`);
    # these indices must be kept in sync with the member list above.
    address_payable_struct.members[1].can_be_private = False
    address_payable_struct.members[2].can_be_private = False
    set_parents(address_payable_struct)

    # msg: `sender` (address payable) and `value` (uint).
    msg_struct: StructDefinition = StructDefinition(Identifier('<msg>'), [
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('sender')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('value')),
    ])
    set_parents(msg_struct)

    # block: `coinbase` (address payable) and uint fields difficulty, gaslimit, number, timestamp.
    block_struct: StructDefinition = StructDefinition(Identifier('<block>'), [
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('coinbase')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('difficulty')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('gaslimit')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('number')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('timestamp')),
    ])
    set_parents(block_struct)

    # tx: `gasprice` (uint) and `origin` (address payable).
    tx_struct: StructDefinition = StructDefinition(Identifier('<tx>'), [
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('gasprice')),
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('origin')),
    ])
    set_parents(tx_struct)
Exemplo n.º 4
0
 def visitVariableDeclaration(self, ast: VariableDeclaration):
     """Force 'memory' storage for private declarations, then visit the children.

     :return: whatever ``self.visit_children(ast)`` returns
     """
     needs_memory = ast.annotated_type.is_private()
     if needs_memory:
         ast.storage_location = 'memory'
     return self.visit_children(ast)
Exemplo n.º 5
0
    def handle_function_body(self, ast: ConstructorOrFunctionDefinition):
        """
        Return offchain simulation python code for the body of function ast.

        In addition to what the original code does, the generated python code also:

        * checks that internal functions are not called externally
        * processes arguments (encryption, address wrapping for external calls),
        * introduces msg, block and tx objects as local variables (populated with current blockchain state)
        * serializes the public circuit outputs and the private circuit inputs, which are obtained during \
          simulation into int lists so that they can be passed to the proof generation
        * generates the NIZK proof (if needed)
        * calls/issues transaction with transformed arguments ((encrypted) original args, out array, proof)
          (or deploys the contract in case of the constructor)

        :param ast: the function or constructor whose body is translated
        :return: generated python source, wrapped in a ``with self._function_ctx(...)`` block
        """
        # Preamble: externally-visible functions assert they were entered via an external call;
        # every function gets local msg/block/tx objects plus a `now` alias for block.timestamp.
        preamble_str = ''
        if ast.is_external:
            preamble_str += f'assert {IS_EXTERNAL_CALL}\n'
        preamble_str += f'msg, block, tx = {api("get_special_variables")}()\n' \
                        f'now = block.timestamp\n'
        circuit = self.current_circ

        # Declare the dict of private circuit values, initialised with the default value of an
        # anonymous struct that has one member per secret circuit input identifier.
        if circuit and circuit.sec_idfs:
            priv_struct = StructDefinition(None, [
                VariableDeclaration([], AnnotatedTypeName(sec_idf.t), sec_idf)
                for sec_idf in circuit.sec_idfs
            ])
            preamble_str += f'\n{PRIV_VALUES_NAME}: Dict[str, Any] = {self.get_default_value(StructTypeName([], priv_struct))}\n'

        all_params = ', '.join(
            [f'{self.visit(param.idf)}' for param in self.current_params])
        if ast.can_be_external:
            # Wrap address strings in AddressValue object for external calls
            address_params = [
                self.visit(param.idf) for param in self.current_params
                if param.annotated_type.zkay_type.is_address()
            ]
            if address_params:
                assign_addr_str = f"{', '.join(address_params)} = {', '.join([f'AddressValue({p})' for p in address_params])}"
                preamble_str += f'\n{self.do_if_external(ast, [assign_addr_str])}\n'

        if ast.can_be_external and circuit:
            # Encrypt parameters and add private circuit inputs (plain + randomness)
            enc_param_str = ''
            for arg in self.current_params:
                if arg.annotated_type.is_cipher():
                    assert isinstance(arg.annotated_type.type_name, CipherText)
                    cipher: CipherText = arg.annotated_type.type_name
                    pname = self.visit(arg.idf)
                    plain_val = pname
                    plain_t = cipher.plain_type.type_name
                    crypto_params = cipher.crypto_params
                    crypto_str = f'crypto_backend="{crypto_params.crypto_name}"'
                    # Backends that encrypt signed values as unsigned get the plaintext cast to uintN first.
                    if plain_t.is_signed_numeric and crypto_params.enc_signed_as_unsigned:
                        plain_val = self.handle_cast(
                            pname,
                            UintTypeName(f'uint{plain_t.elem_bitwidth}'))
                    enc_param_str += f'{self.get_priv_value(arg.idf.name)} = {plain_val}\n'
                    if crypto_params.is_symmetric_cipher():
                        # Symmetric: drop the last element of the enc() payload and append the
                        # caller's own public key instead (no randomness value is recorded).
                        my_pk = f'{api("get_my_pk")}("{crypto_params.crypto_name}")[0]'
                        enc_expr = f'{api("enc")}({self.get_priv_value(arg.idf.name)}, {crypto_str})'
                        enc_param_str += f'{pname} = CipherValue({enc_expr}[0][:-1] + ({my_pk}, ), {crypto_str})\n'
                    else:
                        # Asymmetric: enc() returns (cipher, randomness); the randomness is stored as
                        # the private value "<name>_R".
                        enc_expr = f'{api("enc")}({self.get_priv_value(arg.idf.name)}, {crypto_str})'
                        enc_param_str += f'{pname}, {self.get_priv_value(f"{arg.idf.name}_R")} = {enc_expr}\n'

            enc_param_comment_str = '\n# Encrypt parameters' if enc_param_str else ''
            # Strip the trailing newline of the last generated line.
            enc_param_str = enc_param_str[:-1] if enc_param_str else ''

            actual_params_assign_str = f"actual_params = [{all_params}]"

            # The zero-initialised out array is appended to the transaction arguments.
            out_var_decl_str = f'{cfg.zk_out_name}: List[int] = [0 for _ in range({circuit.out_size_trans})]'
            out_var_decl_str += f'\nactual_params.append({cfg.zk_out_name})'

            pre_body_code = self.do_if_external(ast, [
                enc_param_comment_str, enc_param_str, actual_params_assign_str,
                out_var_decl_str
            ])
        elif ast.can_be_external:
            pre_body_code = f'actual_params = [{all_params}]'
        else:
            pre_body_code = ''

        # Simulate public contract to compute in_values (state variable values are pulled from blockchain if necessary)
        # (out values are also computed when encountered, by locally evaluating and encrypting
        # the corresponding private expressions)
        body_str = self.visit(ast.body).strip()

        # Elements whose type is not primitive are serialized with bitwidth 0.
        # NOTE(review): `zk__data` is hard-coded below — presumably it must match the name used
        # for the circuit data variable elsewhere; confirm.
        serialize_str = ''
        if circuit is not None:
            if circuit.output_idfs:
                out_elemwidths = ', '.join([
                    str(out.t.elem_bitwidth)
                    if out.t.is_primitive_type() else '0'
                    for out in circuit.output_idfs
                ])
                serialize_str += f'\n{cfg.zk_out_name}[{cfg.zk_out_name}_start_idx:{cfg.zk_out_name}_start_idx + {circuit.out_size}] = ' \
                                 f'{api("serialize_circuit_outputs")}(zk__data, [{out_elemwidths}])'
            if circuit.sec_idfs:
                sec_elemwidths = ', '.join([
                    str(sec.t.elem_bitwidth)
                    if sec.t.is_primitive_type() else '0'
                    for sec in circuit.sec_idfs
                ])
                serialize_str += f'\n{api("serialize_private_inputs")}({PRIV_VALUES_NAME}, [{sec_elemwidths}])'
        if serialize_str:
            serialize_str = f'\n# Serialize circuit outputs and/or secret circuit inputs\n' + serialize_str.lstrip(
            )

        # Empty fragments are dropped before joining.
        body_code = '\n'.join(
            dedent(s) for s in [
                f'\n## BEGIN Simulate body',
                body_str,
                '## END Simulate body',
                serialize_str,
            ] if s) + '\n'

        # Add proof to actual argument list (when required)
        generate_proof_str = ''
        fname = f"'{ast.name}'"
        if ast.can_be_external and circuit and ast.has_side_effects:
            generate_proof_str += '\n'.join([
                '\n#Generate proof',
                f"proof = {api('gen_proof')}({fname}, {cfg.zk_in_name}, {cfg.zk_out_name})",
                'actual_params.append(proof)'
            ])

        # Per-parameter "is cipher" flags passed to deploy/transact.
        # NOTE(review): this reads self.current_f.parameters while the rest of this method uses
        # self.current_params — confirm both refer to the same parameter list here.
        should_encrypt = ", ".join([
            str(p.annotated_type.is_cipher())
            for p in self.current_f.parameters
        ])
        # Pick the blockchain interaction: deploy (constructor), transact (side effects),
        # call (pure/view with return values) or nothing.
        if ast.is_constructor:
            invoke_transact_str = f'''
            # Deploy contract
            {api("deploy")}(actual_params, [{should_encrypt}]{", wei_amount=wei_amount" if ast.is_payable else ""})
            '''
        elif ast.has_side_effects:
            invoke_transact_str = f'''
            # Invoke public transaction
            return {api("transact")}({fname}, actual_params, [{should_encrypt}]{", wei_amount=wei_amount" if ast.is_payable else ""})
            '''
        elif ast.return_parameters:
            # One (is_cipher, crypto_name, plain-type constructor) triple per return value.
            constructors = []
            for retparam in ast.return_parameters:
                t = retparam.annotated_type.type_name
                if isinstance(t, CipherText):
                    constr = f'(True, "{t.crypto_params.crypto_name}", {self._get_type_constr(t.plain_type.type_name)})'
                else:
                    constr = f'(False, None, {self._get_type_constr(t)})'
                constructors.append(constr)
            constructors = f"[{', '.join(constructors)}]"

            invoke_transact_str = f'''
            # Call pure/view function and return value
            return {api('call')}({fname}, actual_params, {constructors})
            '''
        else:
            invoke_transact_str = ''

        # The third argument is presumably the code emitted for the non-external path: it returns
        # the internal return variables of functions that require verification — confirm in do_if_external.
        post_body_code = self.do_if_external(ast, [
            generate_proof_str, invoke_transact_str
        ], [
            f'return {", ".join([f"{cfg.return_var_name}_{idx}" for idx in range(len(ast.return_parameters))])}'
            if ast.is_function and ast.requires_verification
            and ast.return_parameters else None
        ])

        # Functions that can never be external assert they are not entered through an external call.
        code = '\n\n'.join(s.strip() for s in [
            f'assert not {IS_EXTERNAL_CALL}'
            if not ast.can_be_external else None,
            dedent(preamble_str), pre_body_code, body_code, post_body_code
        ] if s)

        func_ctx_params = []
        if circuit:
            func_ctx_params.append(str(circuit.priv_in_size_trans))
        if ast.is_payable:
            func_ctx_params.append('wei_amount=wei_amount')
        if ast.can_be_external:
            func_ctx_params.append(f'name={fname}')
            # NOTE(review): `zk__is_ext` is hard-coded here; presumably equal to IS_EXTERNAL_CALL — confirm.
            code = 'with time_measure("transaction_full", skip=not zk__is_ext):\n' + indent(
                code)
        code = f'with self._function_ctx({", ".join(func_ctx_params)}) as {IS_EXTERNAL_CALL}:\n' + indent(
            code)
        return code
Exemplo n.º 6
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition, ext_circuit: CircuitHelper,
                                     original_params: List[Parameter], requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: whether to emit the verifier call (skipped when cfg.disable_verification is set)
        :return: body with wrapper code
        """
        priv_args = [p for p in original_params if p.annotated_type.is_cipher()]
        # Distinct crypto backends used by private arguments, in first-occurrence order.
        args_backends = OrderedDict.fromkeys([p.annotated_type.type_name.crypto_params for p in priv_args])
        stmts = []

        for crypto_params in args_backends:
            assert crypto_params in int_fct.used_crypto_backends
            # If there are any private arguments with homomorphism 'hom', we need the public key for that crypto backend
            ext_circuit._require_public_key_for_label_at(None, Expression.me_expr(), crypto_params)
        for crypto_params in cfg.all_crypto_params():
            if crypto_params.is_symmetric_cipher():
                if (MeExpr(), crypto_params) in ext_circuit.requested_global_keys or crypto_params in args_backends:
                    # Make sure msg.sender's key pair is available in the circuit
                    stmts += ext_circuit.request_private_key(crypto_params)

        # Verify that out parameter has correct size
        stmts += [RequireStatement(IdentifierExpr(cfg.zk_out_name).dot('length').binop('==', NumberLiteralExpr(ext_circuit.out_size_trans)))]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))

        # Request static public keys
        offset = 0
        key_req_stmts = []
        # Offset into the in array of msg.sender's public key, per crypto backend
        # (used below to set the sender field of symmetrically encrypted parameters).
        me_key_idx: Dict[CryptoParams, int] = {}
        if ext_circuit.requested_global_keys:
            # Keys are copied into the in array in request order; no reordering happens here,
            # the position of each me-key is recorded in me_key_idx instead.
            keys = [key for key in ext_circuit.requested_global_keys]

            # One temporary key variable per crypto backend used by the internal function.
            tmp_keys = {}
            for crypto_params in int_fct.used_crypto_backends:
                tmp_key_var = Identifier(f'_tmp_key_{crypto_params.identifier_name}')
                key_req_stmts.append(tmp_key_var.decl_var(AnnotatedTypeName.key_type(crypto_params)))
                tmp_keys[crypto_params] = tmp_key_var
            for (key_owner, crypto_params) in keys:
                tmp_key_var = tmp_keys[crypto_params]
                # `idf` is unused; only the generated assignment is kept (its lhs is redirected to the tmp var).
                idf, assignment = ext_circuit.request_public_key(crypto_params, key_owner, ext_circuit.get_glob_key_name(key_owner, crypto_params))
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Remember me-keys for later use in symmetrically encrypted keys
                if key_owner == MeExpr():
                    assert crypto_params not in me_key_idx
                    me_key_idx[crypto_params] = offset

                # Manually add to circuit inputs
                key_len = crypto_params.key_len
                key_req_stmts.append(in_arr_var.slice(offset, key_len).assign(IdentifierExpr(tmp_key_var.clone()).slice(0, key_len)))
                offset += key_len
                # Checked on every iteration, so ext_circuit.in_size is presumably grown by each
                # request_public_key call — confirm in CircuitHelper.
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                cipher_payload_len = p.annotated_type.type_name.crypto_params.cipher_payload_len
                assign_stmt = in_arr_var.slice(offset, cipher_payload_len).assign(IdentifierExpr(p.idf.clone()).slice(0, cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cipher_payload_len

        # Populate sender field of parameters encrypted with a symmetric cipher
        copy_stmts = []
        for p in original_params:
            if p.annotated_type.is_cipher():
                c = p.annotated_type.type_name
                assert isinstance(c, CipherText)
                if c.crypto_params.is_symmetric_cipher():
                    sender_key = in_arr_var.index(me_key_idx[c.crypto_params])
                    idf = IdentifierExpr(p.idf.clone()).as_type(p.annotated_type.clone())
                    # NOTE(review): the payload length here comes from cfg via the homomorphism, while
                    # the loop above uses c.crypto_params directly — presumably identical; confirm.
                    cipher_payload_len = cfg.get_crypto_params(p.annotated_type.homomorphism).cipher_payload_len
                    # Memory copy = calldata payload elements followed by the sender's public key.
                    lit = ArrayLiteralExpr([idf.clone().index(i) for i in range(cipher_payload_len)] + [sender_key])
                    copy_stmts.append(VariableDeclarationStatement(VariableDeclaration([], p.annotated_type.clone(), p.idf.clone(), 'memory'), lit))
        if copy_stmts:
            param_stmts += [Comment(), Comment('Copy from calldata to memory and set sender field')] + copy_stmts

        # Declare in array
        new_in_array_expr = NewExpr(AnnotatedTypeName(TypeName.dyn_uint_array()), [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(), new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        stmts += Comment.comment_wrap_block('Request static public keys', key_req_stmts)
        stmts += Comment.comment_wrap_block('Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # `args +=` extends the list in place *after* internal_call was built from it; this
            # presumably relies on FunctionCallExpr keeping a reference to the same list — confirm.
            args += [in_arr_var.clone(), NumberLiteralExpr(ext_circuit.in_size),
                     IdentifierExpr(cfg.zk_out_name), NumberLiteralExpr(ext_circuit.out_size)]

        if int_fct.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [VariableDeclarationStatement(deep_copy(vd)) for vd in int_fct.return_var_decls])
            in_call = TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof and not cfg.disable_verification:
            verifier = IdentifierExpr(cfg.get_contract_var_name(ext_circuit.verifier_contract_type.code()))
            verifier_args = [IdentifierExpr(cfg.proof_param_name), IdentifierExpr(cfg.zk_in_name), IdentifierExpr(cfg.zk_out_name)]
            verify = ExpressionStatement(verifier.call(cfg.verification_function_name, verifier_args))
            stmts.append(StatementList([Comment('Verify zk proof of execution'), verify], excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(ReturnStatement(TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls])))

        return Block(stmts)
Exemplo n.º 7
0
    def transform_contract(self, su: SourceUnit, c: ContractDefinition) -> ContractDefinition:
        """
        Transform an entire zkay contract into a public solidity contract.

        This:

        * transforms state variables, function bodies and signatures
        * imports verification contracts
        * adds zk_data structs for each function with verification \
          (to store circuit I/O, to bypass solidity stack limit and allow for easy assignment of array variables),
        * creates external wrapper functions for all public functions which require verification
        * adds circuit IO serialization/deserialization code from/to zk_data struct to all functions which require verification.

        :param su: [SIDE EFFECTS] Source unit of which this contract is part of
        :param c: [SIDE EFFECTS] The contract to transform
        :return: The contract itself
        """

        all_fcts = c.constructor_definitions + c.function_definitions

        # Get list of static owner labels for this contract
        # (me, plus every address state variable that is final or constant).
        global_owners = [Expression.me_expr()]
        for var in c.state_variable_declarations:
            if var.annotated_type.is_address() and (var.is_final or var.is_constant):
                global_owners.append(var.idf)

        # Backup untransformed function bodies
        for fct in all_fcts:
            fct.original_body = deep_copy(fct.body, with_types=True, with_analysis=True)

        # Transform types of normal state variables
        c.state_variable_declarations = self.var_decl_trafo.visit_list(c.state_variable_declarations)

        # Split into functions which require verification and those which don't need a circuit helper
        # req_ext_fcts maps each function that needs an external wrapper to a copy of its
        # (not yet transformed) parameter list.
        req_ext_fcts = {}
        new_fcts, new_constr = [], []
        for fct in all_fcts:
            assert isinstance(fct, ConstructorOrFunctionDefinition)
            if fct.requires_verification or fct.requires_verification_when_external:
                self.circuits[fct] = self.create_circuit_helper(fct, global_owners)

            if fct.requires_verification_when_external:
                req_ext_fcts[fct] = fct.parameters[:]
            elif fct.is_constructor:
                new_constr.append(fct)
            else:
                new_fcts.append(fct)

        # Add constant state variables for external contracts and field prime
        field_prime_decl = StateVariableDeclaration(AnnotatedTypeName.uint_all(), ['public', 'constant'],
                                                    Identifier(cfg.field_prime_var_name),
                                                    NumberLiteralExpr(bn128_scalar_field))
        contract_var_decls = self.include_verification_contracts(su, c)
        c.state_variable_declarations = [field_prime_decl, Comment()]\
                                        + Comment.comment_list('Helper Contracts', contract_var_decls)\
                                        + [Comment('User state variables')]\
                                        + c.state_variable_declarations

        # Transform signatures
        for f in all_fcts:
            f.parameters = self.var_decl_trafo.visit_list(f.parameters)
        for f in c.function_definitions:
            f.return_parameters = self.var_decl_trafo.visit_list(f.return_parameters)
            f.return_var_decls = self.var_decl_trafo.visit_list(f.return_var_decls)

        # Transform bodies
        for fct in all_fcts:
            gen = self.circuits.get(fct, None)
            fct.body = ZkayStatementTransformer(gen).visit(fct.body)

        # Transform (internal) functions which require verification (add the necessary additional parameters and boilerplate code)
        fcts_with_verification = [fct for fct in all_fcts if fct.requires_verification]
        compute_transitive_circuit_io_sizes(fcts_with_verification, self.circuits)
        transform_internal_calls(fcts_with_verification, self.circuits)
        for f in fcts_with_verification:
            circuit = self.circuits[f]
            assert circuit.requires_verification()
            if circuit.requires_zk_data_struct():
                # Add zk data struct for f to contract
                # (one member per circuit output followed by one per circuit input).
                zk_data_struct = StructDefinition(Identifier(circuit.zk_data_struct_name), [
                    VariableDeclaration([], AnnotatedTypeName(idf.t), idf.clone(), '')
                    for idf in circuit.output_idfs + circuit.input_idfs
                ])
                c.struct_definitions.append(zk_data_struct)
            self.create_internal_verification_wrapper(f)

        # Create external wrapper functions where necessary
        # Both the external wrapper and the internal function are kept; only the wrapper
        # can be a constructor.
        for f, params in req_ext_fcts.items():
            ext_f, int_f = self.split_into_external_and_internal_fct(f, params, global_owners)
            if ext_f.is_function:
                new_fcts.append(ext_f)
            else:
                new_constr.append(ext_f)
            new_fcts.append(int_f)

        c.constructor_definitions = new_constr
        c.function_definitions = new_fcts
        return c
Exemplo n.º 8
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition,
                                     ext_circuit: CircuitHelper,
                                     original_params: List[Parameter],
                                     requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: whether to emit the call to the verification contract
        :return: body with wrapper code
        """
        has_priv_args = any(
            [p.annotated_type.is_cipher() for p in original_params])
        stmts = []

        # Private arguments require msg.sender's public key in the circuit.
        if has_priv_args:
            ext_circuit._require_public_key_for_label_at(
                None, Expression.me_expr())
        if cfg.is_symmetric_cipher():
            # Make sure msg.sender's key pair is available in the circuit
            assert any(isinstance(k, MeExpr) for k in ext_circuit.requested_global_keys) \
                   or has_priv_args, "requires verification => both sender keys required"
            stmts += ext_circuit.request_private_key()

        # Verify that out parameter has correct size
        stmts += [
            RequireStatement(
                IdentifierExpr(cfg.zk_out_name).dot('length').binop(
                    '==', NumberLiteralExpr(ext_circuit.out_size_trans)))
        ]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(
            Array(AnnotatedTypeName.uint_all()))

        # Find index of me's public key in requested_global_keys
        glob_me_key_index = -1
        for idx, e in enumerate(ext_circuit.requested_global_keys):
            if isinstance(e, MeExpr):
                glob_me_key_index = idx
                break

        # Request static public keys
        offset = 0
        key_req_stmts = []
        if ext_circuit.requested_global_keys:
            # Ensure that me public key is stored starting at in[0]
            # (swap it with the first key; the symmetric-cipher code below reads in[0] as the sender key).
            keys = [key for key in ext_circuit.requested_global_keys]
            if glob_me_key_index != -1:
                (keys[0], keys[glob_me_key_index]) = (keys[glob_me_key_index],
                                                      keys[0])

            # A single temporary key variable is reused for every requested key.
            tmp_key_var = Identifier('_tmp_key')
            key_req_stmts.append(
                tmp_key_var.decl_var(AnnotatedTypeName.key_type()))
            for key_owner in keys:
                idf, assignment = ext_circuit.request_public_key(
                    key_owner, ext_circuit.get_glob_key_name(key_owner))
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Manually add to circuit inputs
                key_req_stmts.append(
                    in_arr_var.slice(offset, cfg.key_len).assign(
                        IdentifierExpr(tmp_key_var.clone()).slice(
                            0, cfg.key_len)))
                offset += cfg.key_len
                # Checked on every iteration, so ext_circuit.in_size is presumably grown by each
                # request_public_key call — confirm in CircuitHelper.
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                assign_stmt = in_arr_var.slice(
                    offset, cfg.cipher_payload_len).assign(
                        IdentifierExpr(p.idf.clone()).slice(
                            0, cfg.cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cfg.cipher_payload_len

        if cfg.is_symmetric_cipher():
            # Populate sender field of encrypted parameters
            copy_stmts = []
            for p in original_params:
                if p.annotated_type.is_cipher():
                    # in[0] holds msg.sender's public key thanks to the swap above.
                    sender_key = in_arr_var.index(0)
                    idf = IdentifierExpr(p.idf.clone()).as_type(
                        p.annotated_type.clone())
                    lit = ArrayLiteralExpr([
                        idf.clone().index(i)
                        for i in range(cfg.cipher_payload_len)
                    ] + [sender_key])
                    copy_stmts.append(
                        VariableDeclarationStatement(
                            VariableDeclaration([], p.annotated_type.clone(),
                                                p.idf.clone(), 'memory'), lit))
            if copy_stmts:
                param_stmts += [
                    Comment(),
                    Comment(
                        'Copy from calldata to memory and set sender field')
                ] + copy_stmts

            assert glob_me_key_index != -1, "Symmetric cipher but did not request me key"

        # Declare in array
        new_in_array_expr = NewExpr(
            AnnotatedTypeName(TypeName.dyn_uint_array()),
            [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(),
                                              new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        stmts += Comment.comment_wrap_block('Request static public keys',
                                            key_req_stmts)
        stmts += Comment.comment_wrap_block(
            'Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(
            IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # `args +=` extends the list in place *after* internal_call was built from it; this
            # presumably relies on FunctionCallExpr keeping a reference to the same list — confirm.
            args += [
                in_arr_var.clone(),
                NumberLiteralExpr(ext_circuit.in_size),
                IdentifierExpr(cfg.zk_out_name),
                NumberLiteralExpr(ext_circuit.out_size)
            ]

        if int_fct.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [
                VariableDeclarationStatement(deep_copy(vd))
                for vd in int_fct.return_var_decls
            ])
            in_call = TupleExpr([
                IdentifierExpr(vd.idf.clone())
                for vd in int_fct.return_var_decls
            ]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof:
            verifier = IdentifierExpr(
                cfg.get_contract_var_name(
                    ext_circuit.verifier_contract_type.code()))
            verifier_args = [
                IdentifierExpr(cfg.proof_param_name),
                IdentifierExpr(cfg.zk_in_name),
                IdentifierExpr(cfg.zk_out_name)
            ]
            verify = ExpressionStatement(
                verifier.call(cfg.verification_function_name, verifier_args))
            stmts.append(
                StatementList(
                    [Comment('Verify zk proof of execution'), verify],
                    excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(
                ReturnStatement(
                    TupleExpr([
                        IdentifierExpr(vd.idf.clone())
                        for vd in int_fct.return_var_decls
                    ])))

        return Block(stmts)
Exemplo n.º 9
0
 def visitVariableDeclaration(self, ast: VariableDeclaration):
     """Set ``ast.names`` to a one-entry mapping from the declared name to its Identifier."""
     declared = ast.idf
     ast.names = {declared.name: declared}
Exemplo n.º 10
0
 def visitIndexExpr(self, ast: IndexExpr):
     """Give an index expression a fresh, unnamed VariableDeclaration as its target.

     The target's type is the ``value_type`` of the indexed container's type name.
     Only location expressions may be indexed; anything else fails the assertion.
     """
     indexed = ast.arr
     assert isinstance(indexed, LocationExpr), "Function call return value indexing not yet supported"
     container_type = indexed.target.annotated_type.type_name
     ast.target = VariableDeclaration([], container_type.value_type, Identifier(''))