示例#1
0
    def visitStatementList(self, ast: StatementList):
        """
        Rule (1)

        Transform every statement of the list on its own and rebuild the statement list from the results.

        A statement for which the visitor returns None is omitted from the rebuilt list.

        The original code (with annotations removed and `me` replaced by `msg.sender`) is compared with the
        transformed code (with block comments removed). If both are equal, the transformed statement is
        used as is. Otherwise the transformed statement, preceded by its pre_statements, is wrapped in a
        comment block which is labelled with the original statement's code.

        A BlankLine at the very end of the rebuilt list is removed.

        :param ast: [SIDE EFFECT] statement list, its statements are replaced by the transformed ones
        :return: the same statement list object
        """
        result = []
        for original in ast.statements:
            # The original code must be captured before the visitor gets a chance to modify the statement
            source_before = original.code()
            replacement = self.visit(original)
            if replacement is None:
                continue

            # Normalize both versions so that only real changes in appearance remain
            normalized_before = re.sub(f'@{WS_PATTERN}*{ID_PATTERN}', '', source_before)
            normalized_before = re.sub(r'(?=\b)me(?=\b)', 'msg.sender', normalized_before)
            normalized_after = re.sub(r'/\*.*?\*/', '', replacement.code())

            if normalized_before != normalized_after:
                wrapped = Comment.comment_wrap_block(source_before, replacement.pre_statements + [replacement])
                result.extend(wrapped)
            else:
                result.append(replacement)

        # Do not end the list with a blank line
        if result and isinstance(result[-1], BlankLine):
            result.pop()
        ast.statements = result
        return ast
示例#2
0
 def visitStatementList(self, ast: StatementList):
     """
     Propagate the analysis state through the statements of the list.

     The list's `before_analysis` and its statements are handed to `self.propagate`,
     the result is stored in `ast.after_analysis`.
     """
     statements = ast.statements
     state_before = ast.before_analysis
     ast.after_analysis = self.propagate(statements, state_before)
示例#3
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition, ext_circuit: CircuitHelper,
                                     original_params: List[Parameter], requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        The generated body consists of (in this order): the statements returned when requesting msg.sender's
        private key (symmetric backends only), a length check for the out array, the declaration of the in array,
        the public key requests (copied into the in array), the backup of the encrypted arguments (copied into
        the in array), the call of the internal function, optionally the verifier call and, if the internal
        function has return parameters, a return statement.

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: if True (and cfg.disable_verification is not set), a call to the verifier contract is added after the internal call
        :return: body with wrapper code
        """
        priv_args = [p for p in original_params if p.annotated_type.is_cipher()]
        # OrderedDict.fromkeys removes duplicate crypto params while keeping the order of first occurrence
        args_backends = OrderedDict.fromkeys([p.annotated_type.type_name.crypto_params for p in priv_args])
        stmts = []

        for crypto_params in args_backends:
            assert crypto_params in int_fct.used_crypto_backends
            # If there are any private arguments with homomorphism 'hom', we need the public key for that crypto backend
            ext_circuit._require_public_key_for_label_at(None, Expression.me_expr(), crypto_params)
        for crypto_params in cfg.all_crypto_params():
            if crypto_params.is_symmetric_cipher():
                # Only needed if me's key for this backend was requested or an argument is encrypted with this backend
                if (MeExpr(), crypto_params) in ext_circuit.requested_global_keys or crypto_params in args_backends:
                    # Make sure msg.sender's key pair is available in the circuit
                    stmts += ext_circuit.request_private_key(crypto_params)

        # Verify that out parameter has correct size
        stmts += [RequireStatement(IdentifierExpr(cfg.zk_out_name).dot('length').binop('==', NumberLiteralExpr(ext_circuit.out_size_trans)))]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))

        # Request static public keys
        # offset = next free index in the in array (continues to be used for the encrypted parameters below)
        offset = 0
        key_req_stmts = []
        # Maps crypto backend -> offset of me's public key for that backend within the in array
        me_key_idx: Dict[CryptoParams, int] = {}
        if ext_circuit.requested_global_keys:
            # Snapshot of the requested keys; they are written to the in array in this order
            # (the position of the me-keys is not fixed, it is recorded in me_key_idx instead)
            keys = [key for key in ext_circuit.requested_global_keys]

            # One temporary key variable is declared per crypto backend used by the internal function
            tmp_keys = {}
            for crypto_params in int_fct.used_crypto_backends:
                tmp_key_var = Identifier(f'_tmp_key_{crypto_params.identifier_name}')
                key_req_stmts.append(tmp_key_var.decl_var(AnnotatedTypeName.key_type(crypto_params)))
                tmp_keys[crypto_params] = tmp_key_var
            for (key_owner, crypto_params) in keys:
                tmp_key_var = tmp_keys[crypto_params]
                idf, assignment = ext_circuit.request_public_key(crypto_params, key_owner, ext_circuit.get_glob_key_name(key_owner, crypto_params))
                # Redirect the key request so that it assigns to the temporary key variable
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Remember me-keys for later use in symmetrically encrypted keys
                if key_owner == MeExpr():
                    assert crypto_params not in me_key_idx
                    me_key_idx[crypto_params] = offset

                # Manually add to circuit inputs
                key_len = crypto_params.key_len
                key_req_stmts.append(in_arr_var.slice(offset, key_len).assign(IdentifierExpr(tmp_key_var.clone()).slice(0, key_len)))
                offset += key_len
                # Checked after every key: the local offset has to stay in sync with the circuit's in_size
                # (presumably request_public_key grows in_size by key_len - confirm in CircuitHelper)
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                cipher_payload_len = p.annotated_type.type_name.crypto_params.cipher_payload_len
                # Copies the cipher payload of the argument into the in array at the current offset
                assign_stmt = in_arr_var.slice(offset, cipher_payload_len).assign(IdentifierExpr(p.idf.clone()).slice(0, cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cipher_payload_len

        # Populate sender field of parameters encrypted with a symmetric cipher
        copy_stmts = []
        for p in original_params:
            if p.annotated_type.is_cipher():
                c = p.annotated_type.type_name
                assert isinstance(c, CipherText)
                if c.crypto_params.is_symmetric_cipher():
                    # me's public key for this backend was copied to the in array at the recorded offset
                    sender_key = in_arr_var.index(me_key_idx[c.crypto_params])
                    idf = IdentifierExpr(p.idf.clone()).as_type(p.annotated_type.clone())
                    cipher_payload_len = cfg.get_crypto_params(p.annotated_type.homomorphism).cipher_payload_len
                    # New value = payload elements of the argument followed by the sender key
                    lit = ArrayLiteralExpr([idf.clone().index(i) for i in range(cipher_payload_len)] + [sender_key])
                    # Declares a 'memory' variable with the name and type of the parameter, initialized with the literal
                    copy_stmts.append(VariableDeclarationStatement(VariableDeclaration([], p.annotated_type.clone(), p.idf.clone(), 'memory'), lit))
        if copy_stmts:
            param_stmts += [Comment(), Comment('Copy from calldata to memory and set sender field')] + copy_stmts

        # Declare in array
        new_in_array_expr = NewExpr(AnnotatedTypeName(TypeName.dyn_uint_array()), [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(), new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        # The statements collected above are only added now, i.e. after the declaration of the in array they write to
        stmts += Comment.comment_wrap_block('Request static public keys', key_req_stmts)
        stmts += Comment.comment_wrap_block('Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # NOTE: += extends the list object which was passed to FunctionCallExpr above in place
            # (presumably the call expression keeps a reference to it and thus receives the extra arguments - confirm)
            args += [in_arr_var.clone(), NumberLiteralExpr(ext_circuit.in_size),
                     IdentifierExpr(cfg.zk_out_name), NumberLiteralExpr(ext_circuit.out_size)]

        if int_fct.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [VariableDeclarationStatement(deep_copy(vd)) for vd in int_fct.return_var_decls])
            # Assign the result of the internal call to the declared return variables
            in_call = TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof and not cfg.disable_verification:
            verifier = IdentifierExpr(cfg.get_contract_var_name(ext_circuit.verifier_contract_type.code()))
            verifier_args = [IdentifierExpr(cfg.proof_param_name), IdentifierExpr(cfg.zk_in_name), IdentifierExpr(cfg.zk_out_name)]
            verify = ExpressionStatement(verifier.call(cfg.verification_function_name, verifier_args))
            stmts.append(StatementList([Comment('Verify zk proof of execution'), verify], excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(ReturnStatement(TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls])))

        return Block(stmts)
示例#4
0
    def create_internal_verification_wrapper(self, ast: ConstructorOrFunctionDefinition):
        """
        Extend an internal function which requires verification with the parameters and boiler plate code it needs.

        Four parameters are added to the function (in array, in start index, out array, out start index).
        The body is replaced by: size checks for both arrays, the zk data struct declaration (if needed),
        the return variable declarations (if any), the deserialization of the circuit outputs, the original
        statements of the body, the serialization of the circuit inputs and a return statement
        (if the circuit has a return variable).

        :param ast: [SIDE EFFECT] Internal function which requires verification
        """
        circuit = self.circuits[ast]
        body = []

        # Symmetric trafo requires msg.sender access -> change from pure to view
        uses_symmetric_cipher = any([b.is_symmetric_cipher() for b in ast.used_crypto_backends])
        if uses_symmetric_cipher and 'pure' in ast.modifiers:
            ast.modifiers = [('view' if m == 'pure' else m) for m in ast.modifiers]

        # Additional parameters
        in_idx_name = f'{cfg.zk_in_name}_start_idx'
        out_idx_name = f'{cfg.zk_out_name}_start_idx'
        ast.add_param(Array(AnnotatedTypeName.uint_all()), cfg.zk_in_name)
        ast.add_param(AnnotatedTypeName.uint_all(), in_idx_name)
        ast.add_param(Array(AnnotatedTypeName.uint_all()), cfg.zk_out_name)
        ast.add_param(AnnotatedTypeName.uint_all(), out_idx_name)

        # Expressions referring to the new parameters (the start index expressions are reused for (de)serialization)
        out_start = IdentifierExpr(out_idx_name)
        in_start = IdentifierExpr(in_idx_name)
        out_arr = IdentifierExpr(cfg.zk_out_name)
        in_arr = IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))

        # Both arrays have to be large enough: start_idx + size <= length
        out_end = out_start.binop('+', NumberLiteralExpr(circuit.out_size_trans))
        body.append(RequireStatement(out_end.binop('<=', out_arr.dot('length'))))
        in_end = in_start.binop('+', NumberLiteralExpr(circuit.in_size_trans))
        body.append(RequireStatement(in_end.binop('<=', in_arr.dot('length'))))

        # zk_data struct variable (only if the circuit needs one)
        if circuit.requires_zk_data_struct():
            struct_type = StructTypeName([Identifier(circuit.zk_data_struct_name)])
            struct_decl = Identifier(cfg.zk_data_var_name).decl_var(struct_type)
            body.extend([struct_decl, BlankLine()])

        # Return variables (only if the function returns something)
        if ast.return_parameters:
            ret_decls = [VariableDeclarationStatement(vd) for vd in ast.return_var_decls]
            body.extend(Comment.comment_list("Declare return variables", ret_decls))

        # Locate the me-keys within the in array (the requested keys are laid out one after another)
        me_key_idx: Dict[CryptoParams, int] = {}
        key_offset = 0
        for owner, key_params in circuit.requested_global_keys:
            if owner == MeExpr():
                assert key_params not in me_key_idx
                me_key_idx[key_params] = key_offset
            key_offset += key_params.key_len

        # Deserialize out array (if any)
        out_stmts = []
        out_offset = 0
        for out_idf in circuit.output_idfs:
            out_stmts.append(out_idf.deserialize(cfg.zk_out_name, out_start, out_offset))
            out_type = out_idf.t
            if isinstance(out_type, CipherText) and out_type.crypto_params.is_symmetric_cipher():
                # User-encrypted values additionally get their sender field assigned
                # Assumption: crypto_params.key_len == 1 for all symmetric ciphers
                cipher_params = out_type.crypto_params
                assert cipher_params in me_key_idx, "Symmetric cipher but did not request me key"
                sender_key = in_arr.index(me_key_idx[cipher_params])
                sender_field = out_idf.get_loc_expr().index(cipher_params.cipher_payload_len)
                out_stmts.append(sender_field.assign(sender_key))
            out_offset += out_type.size_in_uints
        if out_stmts:
            wrapped_out = Comment.comment_wrap_block("Deserialize output values", out_stmts)
            body.append(StatementList(wrapped_out, excluded_from_simulation=True))

        # Original transformed function body
        body.extend(ast.body.statements)

        # Serialize in parameters to in array (if any)
        in_stmts = []
        in_offset = 0
        for in_idf in circuit.input_idfs:
            in_stmts.append(in_idf.serialize(cfg.zk_in_name, in_start, in_offset))
            in_offset += in_idf.t.size_in_uints
        if in_offset:
            body.append(Comment())
            body.extend(Comment.comment_wrap_block('Serialize input values', in_stmts))

        # Return statement at the end if necessary
        # (was previously replaced by assignment to return_var by ZkayStatementTransformer)
        if circuit.has_return_var:
            ret_exprs = [IdentifierExpr(vd.idf.clone()).override(target=vd) for vd in ast.return_var_decls]
            body.append(ReturnStatement(TupleExpr(ret_exprs)))

        # Replace the contents of the existing statement list object in place
        ast.body.statements[:] = body
示例#5
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition,
                                     ext_circuit: CircuitHelper,
                                     original_params: List[Parameter],
                                     requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        The generated body consists of (in this order): the statements returned when requesting msg.sender's
        private key (symmetric cipher only), a length check for the out array, the declaration of the in array,
        the public key requests (copied into the in array), the backup of the encrypted arguments (copied into
        the in array), the call of the internal function, optionally the verifier call and, if the internal
        function has return parameters, a return statement.

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: if True, a call to the verifier contract is added after the internal call
        :return: body with wrapper code
        """
        has_priv_args = any(
            [p.annotated_type.is_cipher() for p in original_params])
        stmts = []

        if has_priv_args:
            # Encrypted arguments require msg.sender's public key
            ext_circuit._require_public_key_for_label_at(
                None, Expression.me_expr())
        if cfg.is_symmetric_cipher():
            # Make sure msg.sender's key pair is available in the circuit
            assert any(isinstance(k, MeExpr) for k in ext_circuit.requested_global_keys) \
                   or has_priv_args, "requires verification => both sender keys required"
            stmts += ext_circuit.request_private_key()

        # Verify that out parameter has correct size
        stmts += [
            RequireStatement(
                IdentifierExpr(cfg.zk_out_name).dot('length').binop(
                    '==', NumberLiteralExpr(ext_circuit.out_size_trans)))
        ]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(
            Array(AnnotatedTypeName.uint_all()))

        # Find index of me's public key in requested_global_keys
        # (stays -1 if the me key was not requested)
        glob_me_key_index = -1
        for idx, e in enumerate(ext_circuit.requested_global_keys):
            if isinstance(e, MeExpr):
                glob_me_key_index = idx
                break

        # Request static public keys
        # offset = next free index in the in array (continues to be used for the encrypted parameters below)
        offset = 0
        key_req_stmts = []
        if ext_circuit.requested_global_keys:
            # Ensure that me public key is stored starting at in[0]
            # (the swap is performed on a copy, requested_global_keys itself is not reordered)
            keys = [key for key in ext_circuit.requested_global_keys]
            if glob_me_key_index != -1:
                (keys[0], keys[glob_me_key_index]) = (keys[glob_me_key_index],
                                                      keys[0])

            # A single temporary key variable is shared by all key requests
            tmp_key_var = Identifier('_tmp_key')
            key_req_stmts.append(
                tmp_key_var.decl_var(AnnotatedTypeName.key_type()))
            for key_owner in keys:
                idf, assignment = ext_circuit.request_public_key(
                    key_owner, ext_circuit.get_glob_key_name(key_owner))
                # Redirect the key request so that it assigns to the temporary key variable
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Manually add to circuit inputs
                key_req_stmts.append(
                    in_arr_var.slice(offset, cfg.key_len).assign(
                        IdentifierExpr(tmp_key_var.clone()).slice(
                            0, cfg.key_len)))
                offset += cfg.key_len
                # Checked after every key: the local offset has to stay in sync with the circuit's in_size
                # (presumably request_public_key grows in_size by cfg.key_len - confirm in CircuitHelper)
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                # Copies the cipher payload of the argument into the in array at the current offset
                assign_stmt = in_arr_var.slice(
                    offset, cfg.cipher_payload_len).assign(
                        IdentifierExpr(p.idf.clone()).slice(
                            0, cfg.cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cfg.cipher_payload_len

        if cfg.is_symmetric_cipher():
            # Populate sender field of encrypted parameters
            copy_stmts = []
            for p in original_params:
                if p.annotated_type.is_cipher():
                    # in[0] is used as sender key; this relies on the me key having been moved to the front above
                    sender_key = in_arr_var.index(0)
                    idf = IdentifierExpr(p.idf.clone()).as_type(
                        p.annotated_type.clone())
                    # New value = payload elements of the argument followed by the sender key
                    lit = ArrayLiteralExpr([
                        idf.clone().index(i)
                        for i in range(cfg.cipher_payload_len)
                    ] + [sender_key])
                    # Declares a 'memory' variable with the name and type of the parameter, initialized with the literal
                    copy_stmts.append(
                        VariableDeclarationStatement(
                            VariableDeclaration([], p.annotated_type.clone(),
                                                p.idf.clone(), 'memory'), lit))
            if copy_stmts:
                param_stmts += [
                    Comment(),
                    Comment(
                        'Copy from calldata to memory and set sender field')
                ] + copy_stmts

            # With a symmetric cipher the me key must have been among the requested keys
            assert glob_me_key_index != -1, "Symmetric cipher but did not request me key"

        # Declare in array
        new_in_array_expr = NewExpr(
            AnnotatedTypeName(TypeName.dyn_uint_array()),
            [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(),
                                              new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        # The statements collected above are only added now, i.e. after the declaration of the in array they write to
        stmts += Comment.comment_wrap_block('Request static public keys',
                                            key_req_stmts)
        stmts += Comment.comment_wrap_block(
            'Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(
            IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # NOTE: += extends the list object which was passed to FunctionCallExpr above in place
            # (presumably the call expression keeps a reference to it and thus receives the extra arguments - confirm)
            args += [
                in_arr_var.clone(),
                NumberLiteralExpr(ext_circuit.in_size),
                IdentifierExpr(cfg.zk_out_name),
                NumberLiteralExpr(ext_circuit.out_size)
            ]

        if int_fct.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [
                VariableDeclarationStatement(deep_copy(vd))
                for vd in int_fct.return_var_decls
            ])
            # Assign the result of the internal call to the declared return variables
            in_call = TupleExpr([
                IdentifierExpr(vd.idf.clone())
                for vd in int_fct.return_var_decls
            ]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof:
            verifier = IdentifierExpr(
                cfg.get_contract_var_name(
                    ext_circuit.verifier_contract_type.code()))
            verifier_args = [
                IdentifierExpr(cfg.proof_param_name),
                IdentifierExpr(cfg.zk_in_name),
                IdentifierExpr(cfg.zk_out_name)
            ]
            verify = ExpressionStatement(
                verifier.call(cfg.verification_function_name, verifier_args))
            stmts.append(
                StatementList(
                    [Comment('Verify zk proof of execution'), verify],
                    excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(
                ReturnStatement(
                    TupleExpr([
                        IdentifierExpr(vd.idf.clone())
                        for vd in int_fct.return_var_decls
                    ])))

        return Block(stmts)
示例#6
0
 def visitStatementList(self, ast: StatementList):
     """
     Store the result of `collect_children_names` for the statement list in `ast.names`.
     """
     child_names = collect_children_names(ast)
     ast.names = child_names