Exemplo n.º 1
0
    def create_contract_variable(cname: str) -> StateVariableDeclaration:
        """
        Build the state variable through which the contract named 'cname' is accessed.

        The variable is declared ``public constant``, is named according to
        ``cfg.get_contract_var_name(cname)``, has contract type ``cname`` and is
        initialized with the literal 0 cast to that contract type.

        :param cname: name of the contract type
        :return: the newly created state variable declaration
        """
        var_name = Identifier(cfg.get_contract_var_name(cname))
        contract_t = ContractTypeName([Identifier(cname)])
        zero_as_contract = PrimitiveCastExpr(contract_t, NumberLiteralExpr(0))

        return StateVariableDeclaration(
            AnnotatedTypeName(contract_t),
            ['public', 'constant'],
            var_name.clone(),
            zero_as_contract,
        )
Exemplo n.º 2
0
    def split_into_external_and_internal_fct(self, f: ConstructorOrFunctionDefinition, original_params: List[Parameter],
                                             global_owners: List[PrivacyLabelExpr]) -> Tuple[ConstructorOrFunctionDefinition,
                                                                                             ConstructorOrFunctionDefinition]:
        """
        Take public function f and split it into an internal function and an external wrapper function.

        The wrapper keeps f's original identifier, while f itself is renamed to cfg.get_internal_name(f),
        made internal and stripped of its 'payable' modifier.

        :param f: [SIDE EFFECT] function to split (requires_verification_when_external must be set)
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param global_owners: list of static labels (me + final address state variable identifiers)
        :return: Tuple of newly created external and internal function definitions (the internal one is the modified f itself)
        """
        assert f.requires_verification_when_external

        # Create new empty function with same parameters as original -> external wrapper
        if f.is_function:
            new_modifiers = ['external']
            # External functions get deep-copied parameters whose storage location 'memory' is changed to 'calldata'
            original_params = [deep_copy(p, with_types=True).with_changed_storage('memory', 'calldata') for p in original_params]
        else:
            # Constructors stay public and reuse the given parameter objects as-is
            new_modifiers = ['public']
        if f.is_payable:
            new_modifiers.append('payable')

        # A wrapper around a side-effect free function is marked 'view' and gets no proof parameter
        requires_proof = True
        if not f.has_side_effects:
            requires_proof = False
            new_modifiers.append('view')
        new_f = ConstructorOrFunctionDefinition(f.idf, original_params, new_modifiers, f.return_parameters, Block([]))

        # Make original function internal
        f.idf = Identifier(cfg.get_internal_name(f))
        f.modifiers = ['internal' if mod == 'public' else mod for mod in f.modifiers if mod != 'payable']
        f.requires_verification_when_external = False

        # Create new circuit for external function
        circuit = self.create_circuit_helper(new_f, global_owners, self.circuits[f])
        if not f.requires_verification:
            # The internal function needs no circuit of its own, only the wrapper's circuit is kept
            del self.circuits[f]
        self.circuits[new_f] = circuit

        # Set meta attributes and populate body
        new_f.requires_verification = True
        new_f.requires_verification_when_external = True
        # NOTE(review): new_f.called_functions is the same object as f.called_functions (no copy),
        # so the following line also registers f in f's own called_functions -- confirm this is intended
        new_f.called_functions = f.called_functions
        new_f.called_functions[f] = None
        new_f.used_crypto_backends = f.used_crypto_backends
        new_f.body = self.create_external_wrapper_body(f, circuit, original_params, requires_proof)

        # Add out and proof parameter to external wrapper
        storage_loc = 'calldata' if new_f.is_function else 'memory'
        new_f.add_param(Array(AnnotatedTypeName.uint_all()), Identifier(cfg.zk_out_name), storage_loc)

        if requires_proof:
            new_f.add_param(AnnotatedTypeName.proof_type(), Identifier(cfg.proof_param_name), storage_loc)

        return new_f, f
Exemplo n.º 3
0
 def test_assignment_statement(self):
     """A freshly built `x = true;` assignment prints correctly and has no names and no parent."""
     target = IdentifierExpr(Identifier('x'))
     value = BooleanLiteralExpr(True)
     stmt = AssignmentStatement(target, value)

     self.assertIsNotNone(stmt)
     self.assertEqual(str(stmt), 'x = true;')
     self.assertEqual(stmt.children(), [target, value])
     self.assertDictEqual(stmt.names, {})
     self.assertIsNone(stmt.parent)
Exemplo n.º 4
0
    def import_contract(cname: str, su: SourceUnit, corresponding_circuit: Optional[CircuitHelper] = None):
        """
        Import contract 'cname' into the given source unit.

        The import path is always './<cname>.sol'.

        :param cname: contract name (.sol filename stem must match contract type name)
        :param su: [SIDE EFFECT] source unit where contract should be imported (path is appended to su.used_contracts)
        :param corresponding_circuit: [SIDE EFFECT] if contract is a verification contract, this should be the corresponding
                                      circuit helper; the contract type and import path are registered with it
        """
        sol_path = f'./{cname}.sol'
        su.used_contracts.append(sol_path)

        # Plain (non-verification) contracts only need the import itself
        if corresponding_circuit is None:
            return

        contract_t = ContractTypeName([Identifier(cname)])
        corresponding_circuit.register_verification_contract_metadata(contract_t, sol_path)
Exemplo n.º 5
0
class GlobalVars:
    """
    Declarations for the built-in global variables msg, block, tx and now.

    Each variable is a StateVariableDeclaration without modifiers and without an initializer expression.
    After construction, the parent of each declaration's identifier is set to the declaration itself.
    """

    # 'msg': struct-typed variable backed by GlobalDefs.msg_struct
    msg: StateVariableDeclaration = StateVariableDeclaration(
        AnnotatedTypeName.all(
            StructTypeName([GlobalDefs.msg_struct.idf],
                           GlobalDefs.msg_struct)), [], Identifier('msg'),
        None)
    msg.idf.parent = msg

    # 'block': struct-typed variable backed by GlobalDefs.block_struct
    block: StateVariableDeclaration = StateVariableDeclaration(
        AnnotatedTypeName.all(
            StructTypeName([GlobalDefs.block_struct.idf],
                           GlobalDefs.block_struct)), [], Identifier('block'),
        None)
    block.idf.parent = block

    # 'tx': struct-typed variable backed by GlobalDefs.tx_struct
    tx: StateVariableDeclaration = StateVariableDeclaration(
        AnnotatedTypeName.all(
            StructTypeName([GlobalDefs.tx_struct.idf], GlobalDefs.tx_struct)),
        [], Identifier('tx'), None)
    tx.idf.parent = tx

    # 'now': variable of type AnnotatedTypeName.uint_all()
    now: StateVariableDeclaration = StateVariableDeclaration(
        AnnotatedTypeName.uint_all(), [], Identifier('now'), None)
    now.idf.parent = now
Exemplo n.º 6
0
 def get_loc_value(self, arr: Identifier, indices: List[str]) -> str:
     """
     Return the code which refers to the location of the given identifier/array element.

     :param arr: identifier of the variable (or array) to locate
     :param indices: already generated code for the index expressions, outermost first
     :return: code string for the location
     """
     if isinstance(arr, HybridArgumentIdf):
         kind = arr.arg_type
         if kind == HybridArgType.PRIV_CIRCUIT_VAL and not arr.name.startswith('tmp'):
             # Private circuit values (other than tmp*) are looked up in the private value dictionary
             return self.get_priv_value(arr.name)
         if kind == HybridArgType.PUB_CIRCUIT_ARG:
             # Public circuit inputs are addressed through their location expression (zk_data dict)
             return self.visit(arr.get_loc_expr())

     # Everything else: visited identifier followed by one subscript per index
     subscripts = ''.join(f'[{idx}]' for idx in indices)
     return f'{self.visit(arr)}{subscripts}'
Exemplo n.º 7
0
    def get_value(self, idf: IdentifierExpr, indices: List[str]):
        """
        Return the code for reading (rvalue) an identifier or index expression.

        Example: idf = x and indices = [some_addr, 5] stands for x[some_addr][5].

        * Special variables are resolved directly via get_loc_value.
        * State variables are read through self.state, using the variable name followed by all indices as key \
          (the value may need to be requested from the blockchain).
        * Local variables outside of a circuit are read from the self.locals dict.

        :param idf: identifier expression to read
        :param indices: already generated code for the index expressions, outermost first
        :return: code string which evaluates to the value
        """
        ident = idf.idf
        if self.is_special_var(ident):
            return self.get_loc_value(ident, indices)

        if isinstance(idf.target, StateVariableDeclaration):
            key_suffix = ''
            if indices:
                key_suffix = ', ' + ', '.join(indices)
            return f'self.state["{ident.name}"{key_suffix}]'

        # Local variables live in the locals dict unless we are inside a circuit
        stored_in_locals = isinstance(idf.target, VariableDeclaration) and not self.inside_circuit
        loc = Identifier(f'self.locals["{ident.name}"]') if stored_in_locals else ident
        return self.get_loc_value(loc, indices)
Exemplo n.º 8
0
    def visitMemberAccessExpr(self, ast: MemberAccessExpr):
        """
        Generate code for a member access expression.

        Enum value accesses and array length accesses are emitted directly; every other access is
        combined with the pending index expressions (self.current_index) and resolved via get_loc_value.

        :param ast: member access expression (must not target a state variable)
        :return: generated code string
        """
        assert not isinstance(ast.target, StateVariableDeclaration), "State member accesses not handled"

        base, member = ast.expr, ast.member
        base_target = base.target

        # Enum value: <qualified enum name>.<member>
        if isinstance(base_target, EnumDefinition):
            qualified = self.visit_list(base_target.qualified_name, sep=".")
            return f'{qualified}.{self.visit(member)}'

        # Array length is translated to len(...)
        if member.name == 'length' and isinstance(base_target.annotated_type.type_name, Array):
            return f'len({self.visit(base)})'

        # Consume pending indices: reset the visitor state first, then visit them in reversed order
        pending = self.current_index
        if pending:
            self.current_index, self.current_index_t = [], None
            indices = [self.visit(idx) for idx in reversed(pending)]
        else:
            indices = []

        if isinstance(member, HybridArgumentIdf):
            code = f'{self.visit(base)}["{member.name}"]'
        else:
            code = super().visitMemberAccessExpr(ast)
        return self.get_loc_value(Identifier(code), indices)
Exemplo n.º 9
0
class GlobalDefs:
    """
    Struct definitions describing the members of built-in types and global objects.

    Defines the structs '<address>', '<address_payable>', '<msg>', '<block>' and '<tx>'.
    set_parents is called on each struct right after it has been constructed.
    """

    # Currently disabled definition of the built-in gasleft function
    # gasleft: FunctionDefinition = FunctionDefinition(
    #     idf=Identifier('gasleft'),
    #     parameters=[],
    #     modifiers=[],
    #     return_parameters=[Parameter([], annotated_type=AnnotatedTypeName.uint_all(), idf=Identifier(''))],
    #     body=Block([])
    # )
    # gasleft.idf.parent = gasleft

    # '<address>': only exposes the member 'balance'
    address_struct: StructDefinition = StructDefinition(
        Identifier('<address>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance'))
        ])
    set_parents(address_struct)

    # '<address_payable>': 'balance' plus the public functions
    # 'send' (one unnamed parameter, one unnamed bool return value) and 'transfer' (one unnamed parameter, no return value)
    address_payable_struct: StructDefinition = StructDefinition(
        Identifier('<address_payable>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance')),
            ConstructorOrFunctionDefinition(
                Identifier('send'),
                [Parameter([], AnnotatedTypeName.uint_all(), Identifier(''))],
                ['public'],
                [Parameter([], AnnotatedTypeName.bool_all(), Identifier(''))],
                Block([])),
            ConstructorOrFunctionDefinition(
                Identifier('transfer'),
                [Parameter([], AnnotatedTypeName.uint_all(), Identifier(''))],
                ['public'], [], Block([])),
        ])
    # members[1] is 'send' and members[2] is 'transfer' (see member list above)
    address_payable_struct.members[1].can_be_private = False
    address_payable_struct.members[2].can_be_private = False
    set_parents(address_payable_struct)

    # '<msg>': members 'sender' (address payable) and 'value'
    msg_struct: StructDefinition = StructDefinition(Identifier('<msg>'), [
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('sender')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('value')),
    ])
    set_parents(msg_struct)

    # '<block>': members 'coinbase' (address payable), 'difficulty', 'gaslimit', 'number' and 'timestamp'
    block_struct: StructDefinition = StructDefinition(Identifier('<block>'), [
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('coinbase')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('difficulty')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('gaslimit')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('number')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('timestamp')),
    ])
    set_parents(block_struct)

    # '<tx>': members 'gasprice' and 'origin' (address payable)
    tx_struct: StructDefinition = StructDefinition(Identifier('<tx>'), [
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('gasprice')),
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('origin')),
    ])
    set_parents(tx_struct)
Exemplo n.º 10
0
# BUILTIN SPECIAL TYPE DEFINITIONS
from zkay.zkay_ast.ast import AnnotatedTypeName, FunctionTypeName, Parameter, Identifier, StructDefinition, \
    VariableDeclaration, TypeName, StateVariableDeclaration, UserDefinedTypeName, StructTypeName, Block, ConstructorOrFunctionDefinition
from zkay.zkay_ast.pointers.parent_setter import set_parents

# Declaration of the 'length' member (type AnnotatedTypeName.uint_all(), no modifiers);
# presumably used as resolution target for array length accesses -- verify against users of this name
array_length_member = VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                          Identifier('length'))


class GlobalDefs:
    # gasleft: FunctionDefinition = FunctionDefinition(
    #     idf=Identifier('gasleft'),
    #     parameters=[],
    #     modifiers=[],
    #     return_parameters=[Parameter([], annotated_type=AnnotatedTypeName.uint_all(), idf=Identifier(''))],
    #     body=Block([])
    # )
    # gasleft.idf.parent = gasleft

    address_struct: StructDefinition = StructDefinition(
        Identifier('<address>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance'))
        ])
    set_parents(address_struct)

    address_payable_struct: StructDefinition = StructDefinition(
        Identifier('<address_payable>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance')),
            ConstructorOrFunctionDefinition(
Exemplo n.º 11
0
    def evaluate_stmt_in_circuit(self, ast: Statement) -> AssignmentStatement:
        """
        Evaluate an entire statement privately.

        This works by turning the statement into an assignment statement where the

        * lhs is a tuple of all external locations (defined outside statement), which are modified inside the statement
        * rhs is the return value of an inlined function call expression to a virtual function where body = the statement + return statement \
          which returns a tuple of the most recent SSA version of all modified locations

        Note: Modifying external locations which are not owned by @me inside the statement is illegal (would leak information).
        Note: At the moment, this is only used for if statements with a private condition.

        :param ast: the statement to evaluate inside the circuit
        :return: AssignmentStatement as described above; if the statement modifies no location which is in scope at the statement,
                 a placeholder ExpressionStatement (literal 0) is returned instead
        :raise NotImplementedError: if a modified external location does not have a primitive type
        """
        # Start with a placeholder statement and upgrade it to an (initially empty) assignment
        # as soon as at least one modified value is in scope at the statement
        astmt = ExpressionStatement(NumberLiteralExpr(0))
        for var in ast.modified_values:
            if var.in_scope_at(ast):
                astmt = AssignmentStatement(None, None)
                break

        astmt.before_analysis = ast.before_analysis

        # External values written inside statement -> function return values
        ret_params = []
        for var in ast.modified_values:
            if var.in_scope_at(ast):
                # side effect affects location outside statement and has privacy @me
                assert ast.before_analysis.same_partition(
                    var.privacy, Expression.me_expr())
                assert isinstance(
                    var.target,
                    (Parameter, VariableDeclaration, StateVariableDeclaration))
                t = var.target.annotated_type.zkay_type
                if not t.type_name.is_primitive_type():
                    raise NotImplementedError(
                        'Reference types inside private if statements are not supported'
                    )
                ret_t = AnnotatedTypeName(t.type_name, Expression.me_expr(),
                                          t.homomorphism)  # t, but @me
                ret_param = IdentifierExpr(var.target.idf.clone(),
                                           ret_t).override(target=var.target)
                ret_param.statement = astmt
                ret_params.append(ret_param)

        # Build the imaginary function
        # (no parameters, one return parameter per modified location, body = original statement + return of all locations)
        fdef = ConstructorOrFunctionDefinition(
            Identifier('<stmt_fct>'), [], ['private'], [
                Parameter([], ret.annotated_type, ret.target.idf)
                for ret in ret_params
            ], Block([ast, ReturnStatement(TupleExpr(ret_params))]))
        fdef.original_body = fdef.body
        fdef.body.parent = fdef
        fdef.parent = ast

        # inline "Call" to the imaginary function
        fcall = FunctionCallExpr(
            IdentifierExpr('<stmt_fct>').override(target=fdef), [])
        fcall.statement = astmt
        ret_args = self.inline_function_call_into_circuit(fcall)

        # Move all return values out of the circuit
        # (a result which is not a tuple is wrapped, so that it can be zipped with ret_params)
        if not isinstance(ret_args, TupleExpr):
            ret_args = TupleExpr([ret_args])
        for ret_arg in ret_args.elements:
            ret_arg.statement = astmt
        ret_arg_outs = [
            self._get_circuit_output_for_private_expression(
                ret_arg, Expression.me_expr(),
                ret_param.annotated_type.homomorphism)
            for ret_param, ret_arg in zip(ret_params, ret_args.elements)
        ]

        # Create assignment statement
        if ret_params:
            astmt.lhs = TupleExpr(
                [ret_param.clone() for ret_param in ret_params])
            astmt.rhs = TupleExpr(ret_arg_outs)
            return astmt
        else:
            # Nothing to assign -> the placeholder expression statement created at the top is returned
            assert isinstance(astmt, ExpressionStatement)
            return astmt
Exemplo n.º 12
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition, ext_circuit: CircuitHelper,
                                     original_params: List[Parameter], requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        The generated body consists of (in this order): private key requests, a length check for the out array,
        the declaration of the in array, static public key requests, backups of encrypted arguments into the in array,
        the call to the internal function, the verifier call (if enabled) and a return statement (if needed).

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: if True (and cfg.disable_verification is not set), a call to the verification contract is added
        :return: body with wrapper code
        """
        # Encrypted parameters and the (de-duplicated, order-preserving) crypto parameters they use
        priv_args = [p for p in original_params if p.annotated_type.is_cipher()]
        args_backends = OrderedDict.fromkeys([p.annotated_type.type_name.crypto_params for p in priv_args])
        stmts = []

        for crypto_params in args_backends:
            assert crypto_params in int_fct.used_crypto_backends
            # If there are any private arguments using these crypto params, we need the public key for that crypto backend
            ext_circuit._require_public_key_for_label_at(None, Expression.me_expr(), crypto_params)
        for crypto_params in cfg.all_crypto_params():
            if crypto_params.is_symmetric_cipher():
                if (MeExpr(), crypto_params) in ext_circuit.requested_global_keys or crypto_params in args_backends:
                    # Make sure msg.sender's key pair is available in the circuit
                    stmts += ext_circuit.request_private_key(crypto_params)

        # Verify that out parameter has correct size
        stmts += [RequireStatement(IdentifierExpr(cfg.zk_out_name).dot('length').binop('==', NumberLiteralExpr(ext_circuit.out_size_trans)))]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))

        # Request static public keys
        # (offset tracks the next free position in the in array; it is also used for the parameter backups further below)
        offset = 0
        key_req_stmts = []
        me_key_idx: Dict[CryptoParams, int] = {}
        if ext_circuit.requested_global_keys:
            # Ensure that me public key is stored starting at in[0]
            keys = [key for key in ext_circuit.requested_global_keys]

            # One temporary key variable per crypto backend used by the internal function
            tmp_keys = {}
            for crypto_params in int_fct.used_crypto_backends:
                tmp_key_var = Identifier(f'_tmp_key_{crypto_params.identifier_name}')
                key_req_stmts.append(tmp_key_var.decl_var(AnnotatedTypeName.key_type(crypto_params)))
                tmp_keys[crypto_params] = tmp_key_var
            for (key_owner, crypto_params) in keys:
                tmp_key_var = tmp_keys[crypto_params]
                idf, assignment = ext_circuit.request_public_key(crypto_params, key_owner, ext_circuit.get_glob_key_name(key_owner, crypto_params))
                # Redirect the key request result into the temporary key variable
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Remember me-keys for later use in symmetrically encrypted keys
                if key_owner == MeExpr():
                    assert crypto_params not in me_key_idx
                    me_key_idx[crypto_params] = offset

                # Manually add to circuit inputs
                key_len = crypto_params.key_len
                key_req_stmts.append(in_arr_var.slice(offset, key_len).assign(IdentifierExpr(tmp_key_var.clone()).slice(0, key_len)))
                offset += key_len
                # NOTE(review): checked on every iteration, i.e. this relies on request_public_key growing
                # ext_circuit.in_size by key_len each time -- confirm against CircuitHelper
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                cipher_payload_len = p.annotated_type.type_name.crypto_params.cipher_payload_len
                assign_stmt = in_arr_var.slice(offset, cipher_payload_len).assign(IdentifierExpr(p.idf.clone()).slice(0, cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cipher_payload_len

        # Populate sender field of parameters encrypted with a symmetric cipher
        # (the parameter is re-declared as memory variable = payload elements followed by the sender's key from the in array)
        copy_stmts = []
        for p in original_params:
            if p.annotated_type.is_cipher():
                c = p.annotated_type.type_name
                assert isinstance(c, CipherText)
                if c.crypto_params.is_symmetric_cipher():
                    sender_key = in_arr_var.index(me_key_idx[c.crypto_params])
                    idf = IdentifierExpr(p.idf.clone()).as_type(p.annotated_type.clone())
                    cipher_payload_len = cfg.get_crypto_params(p.annotated_type.homomorphism).cipher_payload_len
                    lit = ArrayLiteralExpr([idf.clone().index(i) for i in range(cipher_payload_len)] + [sender_key])
                    copy_stmts.append(VariableDeclarationStatement(VariableDeclaration([], p.annotated_type.clone(), p.idf.clone(), 'memory'), lit))
        if copy_stmts:
            param_stmts += [Comment(), Comment('Copy from calldata to memory and set sender field')] + copy_stmts

        # Declare in array
        new_in_array_expr = NewExpr(AnnotatedTypeName(TypeName.dyn_uint_array()), [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(), new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        stmts += Comment.comment_wrap_block('Request static public keys', key_req_stmts)
        stmts += Comment.comment_wrap_block('Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # args is the same list object which was passed to internal_call, so this also extends the call's arguments
            # with (in array, in start index, out array, out start index)
            args += [in_arr_var.clone(), NumberLiteralExpr(ext_circuit.in_size),
                     IdentifierExpr(cfg.zk_out_name), NumberLiteralExpr(ext_circuit.out_size)]

        if int_fct.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [VariableDeclarationStatement(deep_copy(vd)) for vd in int_fct.return_var_decls])
            in_call = TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof and not cfg.disable_verification:
            verifier = IdentifierExpr(cfg.get_contract_var_name(ext_circuit.verifier_contract_type.code()))
            verifier_args = [IdentifierExpr(cfg.proof_param_name), IdentifierExpr(cfg.zk_in_name), IdentifierExpr(cfg.zk_out_name)]
            verify = ExpressionStatement(verifier.call(cfg.verification_function_name, verifier_args))
            stmts.append(StatementList([Comment('Verify zk proof of execution'), verify], excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(ReturnStatement(TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls])))

        return Block(stmts)
Exemplo n.º 13
0
    def create_internal_verification_wrapper(self, ast: ConstructorOrFunctionDefinition):
        """
        Add the necessary additional parameters and boiler plate code for verification support to the given function.

        The function receives four additional parameters (in array, in start index, out array, out start index) and its
        body is replaced in place by: size checks, zk_data struct declaration (if needed), return variable declarations,
        output deserialization, the original body, input serialization and a final return statement (if needed).

        :param ast: [SIDE EFFECT] Internal function which requires verification
        """
        circuit = self.circuits[ast]
        stmts = []

        symmetric_cipher_used = any([backend.is_symmetric_cipher() for backend in ast.used_crypto_backends])
        if symmetric_cipher_used and 'pure' in ast.modifiers:
            # Symmetric trafo requires msg.sender access -> change from pure to view
            ast.modifiers = ['view' if mod == 'pure' else mod for mod in ast.modifiers]

        # Add additional params
        ast.add_param(Array(AnnotatedTypeName.uint_all()), cfg.zk_in_name)
        ast.add_param(AnnotatedTypeName.uint_all(), f'{cfg.zk_in_name}_start_idx')
        ast.add_param(Array(AnnotatedTypeName.uint_all()), cfg.zk_out_name)
        ast.add_param(AnnotatedTypeName.uint_all(), f'{cfg.zk_out_name}_start_idx')

        # Verify that in/out parameters have correct size
        # (start index + transitive size must not exceed the length of the respective array)
        out_start_idx, in_start_idx = IdentifierExpr(f'{cfg.zk_out_name}_start_idx'), IdentifierExpr(f'{cfg.zk_in_name}_start_idx')
        out_var, in_var = IdentifierExpr(cfg.zk_out_name), IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))
        stmts.append(RequireStatement(out_start_idx.binop('+', NumberLiteralExpr(circuit.out_size_trans)).binop('<=', out_var.dot('length'))))
        stmts.append(RequireStatement(in_start_idx.binop('+', NumberLiteralExpr(circuit.in_size_trans)).binop('<=', in_var.dot('length'))))

        # Declare zk_data struct var (if needed)
        if circuit.requires_zk_data_struct():
            zk_struct_type = StructTypeName([Identifier(circuit.zk_data_struct_name)])
            stmts += [Identifier(cfg.zk_data_var_name).decl_var(zk_struct_type), BlankLine()]

        # Declare return variable if necessary
        if ast.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [VariableDeclarationStatement(vd) for vd in ast.return_var_decls])

        # Find all me-keys in the in array
        # (keys are assumed to be laid out in the order of circuit.requested_global_keys, each occupying key_len entries)
        me_key_idx: Dict[CryptoParams, int] = {}
        offset = 0
        for (key_owner, crypto_params) in circuit.requested_global_keys:
            if key_owner == MeExpr():
                assert crypto_params not in me_key_idx
                me_key_idx[crypto_params] = offset
            offset += crypto_params.key_len

        # Deserialize out array (if any)
        deserialize_stmts = []
        offset = 0
        for s in circuit.output_idfs:
            deserialize_stmts.append(s.deserialize(cfg.zk_out_name, out_start_idx, offset))
            if isinstance(s.t, CipherText) and s.t.crypto_params.is_symmetric_cipher():
                # Assign sender field to user-encrypted values if necessary
                # Assumption: s.t.crypto_params.key_len == 1 for all symmetric ciphers
                assert s.t.crypto_params in me_key_idx, "Symmetric cipher but did not request me key"
                key_idx = me_key_idx[s.t.crypto_params]
                sender_key = in_var.index(key_idx)
                cipher_payload_len = s.t.crypto_params.cipher_payload_len
                # The sender key is stored in the element right after the cipher payload
                deserialize_stmts.append(s.get_loc_expr().index(cipher_payload_len).assign(sender_key))
            offset += s.t.size_in_uints
        if deserialize_stmts:
            stmts.append(StatementList(Comment.comment_wrap_block("Deserialize output values", deserialize_stmts), excluded_from_simulation=True))

        # Include original transformed function body
        stmts += ast.body.statements

        # Serialize in parameters to in array (if any)
        serialize_stmts = []
        offset = 0
        for s in circuit.input_idfs:
            serialize_stmts += [s.serialize(cfg.zk_in_name, in_start_idx, offset)]
            offset += s.t.size_in_uints
        # Only emitted if the serialized inputs occupy at least one array element
        if offset:
            stmts.append(Comment())
            stmts += Comment.comment_wrap_block('Serialize input values', serialize_stmts)

        # Add return statement at the end if necessary
        # (was previously replaced by assignment to return_var by ZkayStatementTransformer)
        if circuit.has_return_var:
            stmts.append(ReturnStatement(TupleExpr([IdentifierExpr(vd.idf.clone()).override(target=vd) for vd in ast.return_var_decls])))

        # Replace the statements in place, so that the existing body object (and its statement list) is kept
        ast.body.statements[:] = stmts
Exemplo n.º 14
0
    def transform_contract(self, su: SourceUnit, c: ContractDefinition) -> ContractDefinition:
        """
        Transform an entire zkay contract into a public solidity contract.

        This:

        * transforms state variables, function bodies and signatures
        * import verification contracts
        * adds zk_data structs for each function with verification \
          (to store circuit I/O, to bypass solidity stack limit and allow for easy assignment of array variables),
        * creates external wrapper functions for all public functions which require verification
        * adds circuit IO serialization/deserialization code from/to zk_data struct to all functions which require verification.

        :param su: [SIDE EFFECTS] Source unit of which this contract is part of
        :param c: [SIDE EFFECTS] The contract to transform
        :return: The contract itself
        """

        all_fcts = c.constructor_definitions + c.function_definitions

        # Get list of static owner labels for this contract
        # (me + identifiers of all final or constant address state variables)
        global_owners = [Expression.me_expr()]
        for var in c.state_variable_declarations:
            if var.annotated_type.is_address() and (var.is_final or var.is_constant):
                global_owners.append(var.idf)

        # Backup untransformed function bodies
        for fct in all_fcts:
            fct.original_body = deep_copy(fct.body, with_types=True, with_analysis=True)

        # Transform types of normal state variables
        c.state_variable_declarations = self.var_decl_trafo.visit_list(c.state_variable_declarations)

        # Split into functions which require verification and those which don't need a circuit helper
        # req_ext_fcts: function -> shallow copy of its parameter list (taken before any parameters are added below)
        # new_fcts / new_constr: functions / constructors which end up in the transformed contract
        req_ext_fcts = {}
        new_fcts, new_constr = [], []
        for fct in all_fcts:
            assert isinstance(fct, ConstructorOrFunctionDefinition)
            if fct.requires_verification or fct.requires_verification_when_external:
                self.circuits[fct] = self.create_circuit_helper(fct, global_owners)

            if fct.requires_verification_when_external:
                # Added to the output lists later, once the external wrapper has been created
                req_ext_fcts[fct] = fct.parameters[:]
            elif fct.is_constructor:
                new_constr.append(fct)
            else:
                new_fcts.append(fct)

        # Add constant state variables for external contracts and field prime
        field_prime_decl = StateVariableDeclaration(AnnotatedTypeName.uint_all(), ['public', 'constant'],
                                                    Identifier(cfg.field_prime_var_name),
                                                    NumberLiteralExpr(bn128_scalar_field))
        contract_var_decls = self.include_verification_contracts(su, c)
        c.state_variable_declarations = [field_prime_decl, Comment()]\
                                        + Comment.comment_list('Helper Contracts', contract_var_decls)\
                                        + [Comment('User state variables')]\
                                        + c.state_variable_declarations

        # Transform signatures
        # (return parameters and return variable declarations are only transformed for functions, not for constructors)
        for f in all_fcts:
            f.parameters = self.var_decl_trafo.visit_list(f.parameters)
        for f in c.function_definitions:
            f.return_parameters = self.var_decl_trafo.visit_list(f.return_parameters)
            f.return_var_decls = self.var_decl_trafo.visit_list(f.return_var_decls)

        # Transform bodies
        # (functions without circuit helper are transformed with gen = None)
        for fct in all_fcts:
            gen = self.circuits.get(fct, None)
            fct.body = ZkayStatementTransformer(gen).visit(fct.body)

        # Transform (internal) functions which require verification (add the necessary additional parameters and boilerplate code)
        fcts_with_verification = [fct for fct in all_fcts if fct.requires_verification]
        compute_transitive_circuit_io_sizes(fcts_with_verification, self.circuits)
        transform_internal_calls(fcts_with_verification, self.circuits)
        for f in fcts_with_verification:
            circuit = self.circuits[f]
            assert circuit.requires_verification()
            if circuit.requires_zk_data_struct():
                # Add zk data struct for f to contract
                # (one member per circuit output, followed by one member per circuit input)
                zk_data_struct = StructDefinition(Identifier(circuit.zk_data_struct_name), [
                    VariableDeclaration([], AnnotatedTypeName(idf.t), idf.clone(), '')
                    for idf in circuit.output_idfs + circuit.input_idfs
                ])
                c.struct_definitions.append(zk_data_struct)
            self.create_internal_verification_wrapper(f)

        # Create external wrapper functions where necessary
        for f, params in req_ext_fcts.items():
            ext_f, int_f = self.split_into_external_and_internal_fct(f, params, global_owners)
            if ext_f.is_function:
                new_fcts.append(ext_f)
            else:
                new_constr.append(ext_f)
            # The internal part always becomes a regular function, even if the wrapper is a constructor
            new_fcts.append(int_f)

        c.constructor_definitions = new_constr
        c.function_definitions = new_fcts
        return c
Exemplo n.º 15
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition,
                                     ext_circuit: CircuitHelper,
                                     original_params: List[Parameter],
                                     requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        The generated statements are emitted in this order:
        (optional) private key request, size check of the out array,
        declaration of the in array, static public key requests, backup of
        encrypted arguments into the in array, (optional) return variable
        declarations, the call to the internal function, (optional) the
        verifier call and (optional) the return statement.

        Statement order inside this function matters: several values
        (e.g. ``ext_circuit.in_size``) are read from ext_circuit between
        calls that may modify it.

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: if True, a call to the verification function of the
                               verifier contract is appended after the internal call
        :return: body with wrapper code
        """
        # True if at least one of the wrapper's parameters is a ciphertext
        has_priv_args = any(
            [p.annotated_type.is_cipher() for p in original_params])
        stmts = []

        if has_priv_args:
            # NOTE(review): presumably this registers me's public key in
            # ext_circuit.requested_global_keys (the assertions below rely on
            # a MeExpr entry being present) - confirm in CircuitHelper.
            ext_circuit._require_public_key_for_label_at(
                None, Expression.me_expr())
        if cfg.is_symmetric_cipher():
            # Make sure msg.sender's key pair is available in the circuit
            assert any(isinstance(k, MeExpr) for k in ext_circuit.requested_global_keys) \
                   or has_priv_args, "requires verification => both sender keys required"
            stmts += ext_circuit.request_private_key()

        # Verify that out parameter has correct size
        # (the length of the out array must equal ext_circuit.out_size_trans
        # as it is at this point of the transformation)
        stmts += [
            RequireStatement(
                IdentifierExpr(cfg.zk_out_name).dot('length').binop(
                    '==', NumberLiteralExpr(ext_circuit.out_size_trans)))
        ]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(
            Array(AnnotatedTypeName.uint_all()))

        # Find index of me's public key in requested_global_keys
        # (stays -1 if no MeExpr key was requested)
        glob_me_key_index = -1
        for idx, e in enumerate(ext_circuit.requested_global_keys):
            if isinstance(e, MeExpr):
                glob_me_key_index = idx
                break

        # Request static public keys
        # offset is the next free position in the in array; it is advanced
        # manually for every value written by this function.
        offset = 0
        key_req_stmts = []
        if ext_circuit.requested_global_keys:
            # Ensure that me public key is stored starting at in[0]
            # (work on a copy so that requested_global_keys itself is not reordered)
            keys = [key for key in ext_circuit.requested_global_keys]
            if glob_me_key_index != -1:
                (keys[0], keys[glob_me_key_index]) = (keys[glob_me_key_index],
                                                      keys[0])

            # A single temporary key variable is declared once and reused for all keys
            tmp_key_var = Identifier('_tmp_key')
            key_req_stmts.append(
                tmp_key_var.decl_var(AnnotatedTypeName.key_type()))
            for key_owner in keys:
                # Only the assignment is used; its target is redirected to the
                # temporary key variable (the returned idf is not used here).
                idf, assignment = ext_circuit.request_public_key(
                    key_owner, ext_circuit.get_glob_key_name(key_owner))
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Manually add to circuit inputs
                key_req_stmts.append(
                    in_arr_var.slice(offset, cfg.key_len).assign(
                        IdentifierExpr(tmp_key_var.clone()).slice(
                            0, cfg.key_len)))
                offset += cfg.key_len
                # NOTE(review): checked after every key; presumably
                # request_public_key grows ext_circuit.in_size by cfg.key_len,
                # so the manual offset must stay in sync with it - confirm.
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                # Copy the ciphertext payload of p into the in array at the current offset
                assign_stmt = in_arr_var.slice(
                    offset, cfg.cipher_payload_len).assign(
                        IdentifierExpr(p.idf.clone()).slice(
                            0, cfg.cipher_payload_len))
                # The copy statement is handed to the circuit helper together with p
                # before it is added to the wrapper statements.
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cfg.cipher_payload_len

        if cfg.is_symmetric_cipher():
            # Populate sender field of encrypted parameters
            copy_stmts = []
            for p in original_params:
                if p.annotated_type.is_cipher():
                    # in[0] is where me's public key starts (see key ordering above)
                    sender_key = in_arr_var.index(0)
                    idf = IdentifierExpr(p.idf.clone()).as_type(
                        p.annotated_type.clone())
                    # Array literal: the payload elements of p followed by the sender key
                    lit = ArrayLiteralExpr([
                        idf.clone().index(i)
                        for i in range(cfg.cipher_payload_len)
                    ] + [sender_key])
                    # Declares a 'memory' variable with the same identifier as the
                    # parameter, initialized with the literal above.
                    copy_stmts.append(
                        VariableDeclarationStatement(
                            VariableDeclaration([], p.annotated_type.clone(),
                                                p.idf.clone(), 'memory'), lit))
            if copy_stmts:
                param_stmts += [
                    Comment(),
                    Comment(
                        'Copy from calldata to memory and set sender field')
                ] + copy_stmts

            assert glob_me_key_index != -1, "Symmetric cipher but did not request me key"

        # Declare in array
        # (sized with ext_circuit.in_size_trans as it is after the key and
        # parameter handling above)
        new_in_array_expr = NewExpr(
            AnnotatedTypeName(TypeName.dyn_uint_array()),
            [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(),
                                              new_in_array_expr)
        # The in array declaration is emitted first; the key and parameter
        # statements collected above write into it and therefore follow it.
        stmts.append(in_var_decl)
        stmts.append(Comment())
        stmts += Comment.comment_wrap_block('Request static public keys',
                                            key_req_stmts)
        stmts += Comment.comment_wrap_block(
            'Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(
            IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # Extra arguments for the internal function: the in array, the
            # current ext_circuit.in_size, the out array and the current
            # ext_circuit.out_size.
            # NOTE(review): args was already passed to FunctionCallExpr above;
            # extending it here only affects the call if FunctionCallExpr
            # keeps a reference to this same list - confirm.
            args += [
                in_arr_var.clone(),
                NumberLiteralExpr(ext_circuit.in_size),
                IdentifierExpr(cfg.zk_out_name),
                NumberLiteralExpr(ext_circuit.out_size)
            ]

        if int_fct.return_parameters:
            # Declare one local variable per return variable of the internal
            # function and assign the call result to them as a tuple.
            stmts += Comment.comment_list("Declare return variables", [
                VariableDeclarationStatement(deep_copy(vd))
                for vd in int_fct.return_var_decls
            ])
            in_call = TupleExpr([
                IdentifierExpr(vd.idf.clone())
                for vd in int_fct.return_var_decls
            ]).assign(internal_call)
        else:
            # No return values: the call is a plain expression statement
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof:
            # The verifier is addressed through the contract variable named after
            # the circuit's verifier contract type.
            verifier = IdentifierExpr(
                cfg.get_contract_var_name(
                    ext_circuit.verifier_contract_type.code()))
            verifier_args = [
                IdentifierExpr(cfg.proof_param_name),
                IdentifierExpr(cfg.zk_in_name),
                IdentifierExpr(cfg.zk_out_name)
            ]
            verify = ExpressionStatement(
                verifier.call(cfg.verification_function_name, verifier_args))
            # The verification call is wrapped in a statement list flagged as
            # excluded from simulation.
            stmts.append(
                StatementList(
                    [Comment('Verify zk proof of execution'), verify],
                    excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(
                ReturnStatement(
                    TupleExpr([
                        IdentifierExpr(vd.idf.clone())
                        for vd in int_fct.return_var_decls
                    ])))

        return Block(stmts)
Exemplo n.º 16
0
 def visitIndexExpr(self, ast: IndexExpr):
     assert isinstance(ast.arr, LocationExpr), "Function call return value indexing not yet supported"
     source_t = ast.arr.target.annotated_type.type_name
     ast.target = VariableDeclaration([], source_t.value_type, Identifier(''))