示例#1
0
    def visitAnnotatedTypeName(self, ast: AnnotatedTypeName):
        """
        Type-check an annotated type name.

        A user-defined type name is only accepted if it refers to an enum, in which case it is
        replaced by a clone of the enum's underlying type name. Non-@all annotations are only
        allowed on types which can be private, and identifier privacy annotations must refer to
        a mapping key or to a final/constant public address.

        :param ast: the annotated type name to check (its type_name may be replaced)
        :raises TypeException: on unsupported user-defined types, unsupported private types or
                               invalid privacy annotations
        """
        declared = ast.type_name
        if type(declared) == UserDefinedTypeName:
            if not isinstance(declared.target, EnumDefinition):
                raise TypeException('Unsupported use of user-defined type', declared)
            # Enums are represented by their underlying type from here on
            ast.type_name = declared.target.annotated_type.type_name.clone()

        annotated_non_public = ast.privacy_annotation != Expression.all_expr()
        if annotated_non_public and not ast.type_name.can_be_private():
            raise TypeException(f'Currently, we do not support private {str(ast.type_name)}', ast)

        label = ast.privacy_annotation
        if not isinstance(label, IdentifierExpr):
            return

        owner_decl = label.target
        if isinstance(owner_decl, Mapping):
            # no action necessary, this is the case: mapping(address!x => uint@x)
            return
        if not owner_decl.is_final and not owner_decl.is_constant:
            raise TypeException(
                'Privacy annotations must be "final" or "constant", if they are expressions',
                label)
        if owner_decl.annotated_type != AnnotatedTypeName.address_all():
            raise TypeException(
                f'Privacy type is not a public address, but {str(owner_decl.annotated_type)}',
                label)
示例#2
0
    def split_into_external_and_internal_fct(self, f: ConstructorOrFunctionDefinition, original_params: List[Parameter],
                                             global_owners: List[PrivacyLabelExpr]) -> Tuple[ConstructorOrFunctionDefinition,
                                                                                             ConstructorOrFunctionDefinition]:
        """
        Take public function f and split it into an internal function and an external wrapper function.

        The wrapper takes over f's original identifier object, the given parameters and f's return
        parameters, gets a circuit created via create_circuit_helper (from f's circuit), an additional
        out-array parameter and, if f has side effects, a proof parameter. f itself is renamed to its
        internal name and its 'public' modifier is replaced by 'internal'.

        :param f: [SIDE EFFECT] function to split (at least requires_verification_when_external)
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param global_owners: list of static labels (me + final address state variable identifiers)
        :return: Tuple of newly created external and internal function definitions
        """
        assert f.requires_verification_when_external

        # Create new empty function with same parameters as original -> external wrapper
        if f.is_function:
            # Functions get an 'external' wrapper; its parameters are deep copies whose storage
            # location is changed from 'memory' to 'calldata' (via with_changed_storage)
            new_modifiers = ['external']
            original_params = [deep_copy(p, with_types=True).with_changed_storage('memory', 'calldata') for p in original_params]
        else:
            # Non-functions (presumably constructors) stay 'public'; parameters are reused as-is
            new_modifiers = ['public']
        if f.is_payable:
            new_modifiers.append('payable')

        # Functions without side effects become 'view' and do not get a proof parameter
        requires_proof = True
        if not f.has_side_effects:
            requires_proof = False
            new_modifiers.append('view')
        # NB: f.idf is read here, before f is renamed below, so the wrapper keeps the original name
        new_f = ConstructorOrFunctionDefinition(f.idf, original_params, new_modifiers, f.return_parameters, Block([]))

        # Make original function internal
        # ('payable' is dropped since internal functions cannot be payable in Solidity)
        f.idf = Identifier(cfg.get_internal_name(f))
        f.modifiers = ['internal' if mod == 'public' else mod for mod in f.modifiers if mod != 'payable']
        f.requires_verification_when_external = False

        # Create new circuit for external function
        circuit = self.create_circuit_helper(new_f, global_owners, self.circuits[f])
        # f only keeps its own circuit if it still needs verification as an internal function
        if not f.requires_verification:
            del self.circuits[f]
        self.circuits[new_f] = circuit

        # Set meta attributes and populate body
        new_f.requires_verification = True
        new_f.requires_verification_when_external = True
        # NOTE(review): this assigns the same container object, so the next line also records f
        # as calling itself in f.called_functions. Confirm this aliasing is intended (otherwise copy it).
        new_f.called_functions = f.called_functions
        new_f.called_functions[f] = None
        new_f.used_crypto_backends = f.used_crypto_backends
        new_f.body = self.create_external_wrapper_body(f, circuit, original_params, requires_proof)

        # Add out and proof parameter to external wrapper
        # ('calldata' for the external function wrapper, 'memory' for the public non-function case)
        storage_loc = 'calldata' if new_f.is_function else 'memory'
        new_f.add_param(Array(AnnotatedTypeName.uint_all()), Identifier(cfg.zk_out_name), storage_loc)

        if requires_proof:
            new_f.add_param(AnnotatedTypeName.proof_type(), Identifier(cfg.proof_param_name), storage_loc)

        return new_f, f
示例#3
0
 def handle_cast(self, expr: Expression, t: TypeName) -> AnnotatedTypeName:
     """
     Return the annotated type of casting expr to type t.

     Castability itself was already established by the fake solidity check, so only
     privacy is handled here: a public operand yields a public result, while a private
     operand must be provably @me and yields a result that is private to me.

     :param expr: the expression being cast
     :param t: the target type (cloned into the result)
     :raises TypeMismatchException: if expr is private but not an instance of <type>@me
     """
     if not expr.annotated_type.is_private():
         return AnnotatedTypeName(t.clone())

     required = AnnotatedTypeName(expr.annotated_type.type_name, Expression.me_expr())
     if not expr.instanceof(required):
         raise TypeMismatchException(required, expr.annotated_type, expr)
     return AnnotatedTypeName(t.clone(), Expression.me_expr())
示例#4
0
    def visitFunctionCallExpr(self, ast: FunctionCallExpr):
        """
        Type-check a function call expression and set its annotated type.

        Builtin calls are delegated to handle_builtin_function_call, casts (enum only) to
        handle_cast. For calls through a location expression, each argument is replaced by
        the result of checking it against the corresponding parameter type via get_rhs. The
        call's type is the single return type, or otherwise a tuple of all return types.

        :raises NotImplementedError: for casts to user types other than enums
        :raises TypeException: on an argument count mismatch or an unsupported callee
        """
        callee = ast.func
        if isinstance(callee, BuiltinFunction):
            self.handle_builtin_function_call(ast, callee)
            return

        if ast.is_cast:
            if not isinstance(callee.target, EnumDefinition):
                raise NotImplementedError('User type casts only implemented for enums')
            ast.annotated_type = self.handle_cast(ast.args[0], callee.target.annotated_type.type_name)
            return

        if not isinstance(callee, LocationExpr):
            raise TypeException('Invalid function call', ast)

        signature = callee.annotated_type.type_name
        if len(signature.parameters) != len(ast.args):
            raise TypeException("Wrong number of arguments", callee)

        # Check arguments, replacing each one by its checked (possibly converted) version
        for pos, (arg, param) in enumerate(zip(ast.args, signature.parameters)):
            ast.args[pos] = self.get_rhs(arg, param.annotated_type)

        # Set expression type to return type
        outputs = signature.return_parameters
        if len(outputs) == 1:
            ast.annotated_type = outputs[0].annotated_type.clone()
        else:
            # TODO maybe not None label in the future
            ast.annotated_type = AnnotatedTypeName(TupleType([out.annotated_type for out in outputs]), None)
示例#5
0
    def create_contract_variable(cname: str) -> StateVariableDeclaration:
        """
        Create a public constant state variable with which contract with name 'cname' can be accessed.

        The variable is named via cfg.get_contract_var_name(cname), has the contract type
        `cname` and is initialized with the cast expression `cname(0)`.

        :param cname: name of the contract
        :return: the new state variable declaration
        """
        var_name = Identifier(cfg.get_contract_var_name(cname))
        contract_type = ContractTypeName([Identifier(cname)])
        zero_as_contract = PrimitiveCastExpr(contract_type, NumberLiteralExpr(0))
        return StateVariableDeclaration(
            AnnotatedTypeName(contract_type),
            ['public', 'constant'],
            var_name.clone(),
            zero_as_contract,
        )
示例#6
0
 def make_private_if_not_already(self, ast: Expression):
     """
     Return ast as an expression that is private to me.

     A public ast is wrapped via make_private(ast, me). An already private ast must be
     provably @me and is returned unchanged.

     :raises TypeMismatchException: if ast is private but not an instance of <type>@me
     """
     if not ast.annotated_type.is_private():
         return self.make_private(ast, Expression.me_expr())

     owned_by_me = AnnotatedTypeName(ast.annotated_type.type_name, Expression.me_expr())
     if not ast.instanceof(owned_by_me):
         raise TypeMismatchException(owned_by_me, ast.annotated_type, ast)
     return ast
示例#7
0
 def visitReturnStatement(self, ast: ReturnStatement):
     """
     Type-check a return statement against the enclosing function's return type.

     A bare `return` is checked as an empty tuple and ast.expr stays None. A single
     non-tuple expression is wrapped in a one-element TupleExpr before checking. In
     both non-empty cases ast.expr is replaced by the result of get_rhs.
     """
     assert ast.function.is_function
     expected = AnnotatedTypeName(ast.function.return_type)
     returned = ast.expr

     if returned is None:
         self.get_rhs(TupleExpr([]), expected)
         return

     if not isinstance(returned, TupleExpr):
         returned = TupleExpr([returned])
     ast.expr = self.get_rhs(returned, expected)
示例#8
0
    def make_rehom(expr: Expression, expected_type: AnnotatedTypeName):
        """
        Wrap expr in a RehomExpr which converts it to the homomorphism of expected_type.

        Both expr's type and expected_type must be private and provably @me.

        :param expr: the private expression to convert
        :param expected_type: the target type; its privacy label and homomorphism are used
        :return: the new RehomExpr, typed and placed at expr's location
        """
        label = expected_type.privacy_annotation.privacy_annotation_label()
        assert label is not None
        assert expr.annotated_type.is_private_at_me(expr.analysis)
        assert expected_type.is_private_at_me(expr.analysis)

        target_hom = expected_type.homomorphism
        wrapped = RehomExpr(expr, target_hom)

        # set type
        owner = get_privacy_expr_from_label(label)
        wrapped.annotated_type = AnnotatedTypeName(expr.annotated_type.type_name, owner, target_hom)
        TypeCheckVisitor.check_for_invalid_private_type(wrapped)

        # set statement, parents, location
        TypeCheckVisitor.assign_location(wrapped, expr)
        return wrapped
示例#9
0
 def visitIfStatement(self, ast: IfStatement):
     """
     Check that the condition of an if statement is a boolean.

     A private condition must additionally be provably @me.

     :raises TypeMismatchException: if the condition is not a boolean, or is private but not bool@me
     """
     cond = ast.condition
     if not cond.instanceof_data_type(TypeName.bool_type()):
         raise TypeMismatchException(TypeName.bool_type(), cond.annotated_type.type_name, cond)

     if not cond.annotated_type.is_private():
         return

     bool_at_me = AnnotatedTypeName(TypeName.bool_type(), Expression.me_expr())
     if not cond.instanceof(bool_at_me):
         raise TypeMismatchException(bool_at_me, cond.annotated_type, cond)
示例#10
0
    def visitAnnotatedTypeName(self, ast: AnnotatedTypeName):
        """
        Type-check an annotated type name, including its homomorphism.

        A user-defined type name is only accepted if it refers to an enum, and is then replaced by
        a clone of the enum's underlying type name. Non-@all annotations require a type that can be
        private. Homomorphic annotations additionally require an unsigned numeric type of at most
        32 bits. Identifier privacy annotations must refer to a mapping key or to a final/constant
        public address.

        :param ast: the annotated type name to check (its type_name may be replaced)
        :raises TypeException: on any unsupported type / annotation combination
        """
        if type(ast.type_name) == UserDefinedTypeName:
            referenced = ast.type_name.target
            if not isinstance(referenced, EnumDefinition):
                raise TypeException('Unsupported use of user-defined type', ast.type_name)
            ast.type_name = referenced.annotated_type.type_name.clone()

        if ast.privacy_annotation != Expression.all_expr():
            base = ast.type_name
            if not base.can_be_private():
                raise TypeException(f'Currently, we do not support private {str(base)}', ast)
            if ast.homomorphism != Homomorphism.NON_HOMOMORPHIC:
                # only support uint8, uint16, uint24, uint32 homomorphic data types
                if not base.is_numeric:
                    raise TypeException(
                        f'Homomorphic type not supported for {str(base)}: Only numeric types supported',
                        ast)
                if base.signed:
                    raise TypeException(
                        f'Homomorphic type not supported for {str(base)}: Only unsigned types supported',
                        ast)
                if base.elem_bitwidth > 32:
                    raise TypeException(
                        f'Homomorphic type not supported for {str(base)}: Only up to 32-bit numeric types supported',
                        ast)

        label = ast.privacy_annotation
        if not isinstance(label, IdentifierExpr):
            return

        owner_decl = label.target
        if isinstance(owner_decl, Mapping):
            # no action necessary, this is the case: mapping(address!x => uint@x)
            return
        if not owner_decl.is_final and not owner_decl.is_constant:
            raise TypeException(
                'Privacy annotations must be "final" or "constant", if they are expressions',
                label)
        if owner_decl.annotated_type != AnnotatedTypeName.address_all():
            raise TypeException(
                f'Privacy type is not a public address, but {str(owner_decl.annotated_type)}',
                label)
示例#11
0
    def visitReclassifyExpr(self, ast: ReclassifyExpr):
        """
        Type-check a reveal expression and set its annotated type.

        The target privacy must be usable as a privacy label. The result gets the operand's
        type name with the target privacy, and a reveal whose result is already an instance
        of the operand's type is rejected as redundant.

        :raises TypeException: on an invalid target privacy or a redundant reveal
        """
        target = ast.privacy
        if not target.privacy_annotation_label():
            raise TypeException('Second argument of "reveal" cannot be used as a privacy type', ast)

        # NB prevent any redundant reveal (not just for public)
        operand_type = ast.expr.annotated_type
        ast.annotated_type = AnnotatedTypeName(operand_type.type_name, target)
        if ast.instanceof(operand_type) is True:
            raise TypeException(f'Redundant "reveal": Expression is already "@{target.code()}"', ast)

        self.check_for_invalid_private_type(ast)
示例#12
0
 def implicitly_converted_to(expr: Expression, t: TypeName) -> Expression:
     """
     Wrap expr in an implicit PrimitiveCastExpr to type t.

     The cast takes over expr's parent, statement and source location, becomes expr's
     new parent, and keeps (a clone of) expr's privacy annotation.

     :param expr: expression of primitive type to convert
     :param t: target type (cloned, not shared)
     :return: the new cast expression
     """
     assert expr.annotated_type.type_name.is_primitive_type()

     cast = PrimitiveCastExpr(t.clone(), expr, is_implicit=True)
     cast = cast.override(parent=expr.parent, statement=expr.statement,
                          line=expr.line, column=expr.column)

     # Re-link the tree: the cast's type node and the original expression hang below the cast
     cast.elem_type.parent = cast
     expr.parent = cast

     converted_type = AnnotatedTypeName(t.clone(), expr.annotated_type.privacy_annotation.clone())
     cast.annotated_type = converted_type.override(parent=cast)
     return cast
示例#13
0
class GlobalVars:
    """Declarations of the built-in global variables `msg`, `block`, `tx` and `now`."""

    # `msg`: typed as the struct GlobalDefs.msg_struct, wrapped via AnnotatedTypeName.all
    # (presumably public / @all)
    msg: StateVariableDeclaration = StateVariableDeclaration(
        AnnotatedTypeName.all(
            StructTypeName([GlobalDefs.msg_struct.idf],
                           GlobalDefs.msg_struct)), [], Identifier('msg'),
        None)
    # link the identifier back to its declaration
    msg.idf.parent = msg

    # `block`: typed as the struct GlobalDefs.block_struct
    block: StateVariableDeclaration = StateVariableDeclaration(
        AnnotatedTypeName.all(
            StructTypeName([GlobalDefs.block_struct.idf],
                           GlobalDefs.block_struct)), [], Identifier('block'),
        None)
    block.idf.parent = block

    # `tx`: typed as the struct GlobalDefs.tx_struct
    tx: StateVariableDeclaration = StateVariableDeclaration(
        AnnotatedTypeName.all(
            StructTypeName([GlobalDefs.tx_struct.idf], GlobalDefs.tx_struct)),
        [], Identifier('tx'), None)
    tx.idf.parent = tx

    # `now`: a uint declared with AnnotatedTypeName.uint_all()
    now: StateVariableDeclaration = StateVariableDeclaration(
        AnnotatedTypeName.uint_all(), [], Identifier('now'), None)
    now.idf.parent = now
示例#14
0
    def make_private(expr: Expression, privacy: Expression,
                     homomorphism: Homomorphism):
        """
        Wrap expr in a ReclassifyExpr making it private to `privacy` with `homomorphism`.

        :param expr: expression to reclassify
        :param privacy: privacy annotation expression; must be usable as a privacy label
        :param homomorphism: homomorphism of the resulting private value
        :return: the new ReclassifyExpr, typed and placed at expr's location
        """
        label = privacy.privacy_annotation_label()
        assert label is not None

        owner = get_privacy_expr_from_label(label)
        reclassified = ReclassifyExpr(expr, owner, homomorphism)

        # set type
        reclassified.annotated_type = AnnotatedTypeName(
            expr.annotated_type.type_name, owner.clone(), homomorphism)
        TypeCheckVisitor.check_for_invalid_private_type(reclassified)

        # set statement, parents, location
        TypeCheckVisitor.assign_location(reclassified, expr)
        return reclassified
示例#15
0
    def implicitly_converted_to(expr: Expression, t: TypeName) -> Expression:
        """
        Return expr implicitly converted to type t.

        For a ReclassifyExpr with a non-@all target privacy, the reclassified argument is
        converted instead, and the ReclassifyExpr's type name is updated and returned.
        Otherwise expr is wrapped in an implicit PrimitiveCastExpr, which takes over expr's
        parent, statement and location and keeps expr's privacy and homomorphism.

        :param expr: expression of primitive type to convert
        :param t: target type (cloned, not shared)
        :return: the converted expression
        """
        if isinstance(expr, ReclassifyExpr) and not expr.privacy.is_all_expr():
            # Cast the argument of the ReclassifyExpr instead
            converted_arg = TypeCheckVisitor.implicitly_converted_to(expr.expr, t)
            expr.expr = converted_arg
            expr.annotated_type.type_name = converted_arg.annotated_type.type_name
            return expr

        assert expr.annotated_type.type_name.is_primitive_type()

        cast = PrimitiveCastExpr(t.clone(), expr, is_implicit=True)
        cast = cast.override(parent=expr.parent, statement=expr.statement,
                             line=expr.line, column=expr.column)
        cast.elem_type.parent = cast
        expr.parent = cast

        source_type = expr.annotated_type
        converted_type = AnnotatedTypeName(t.clone(), source_type.privacy_annotation.clone(),
                                           source_type.homomorphism)
        cast.annotated_type = converted_type.override(parent=cast)
        return cast
示例#16
0
    def visitIndexExpr(self, ast: IndexExpr):
        """
        Type-check an index expression into a mapping or an array and set its type.

        Mapping: the key must be an instance of the public key type. If the mapping has a key
        label, the key must be usable as a privacy label and is recorded as the mapping type's
        instantiated_key. The result (the value type) must be accessible by the invoker.
        Array: the index must be public and numeric; the result is the value type.

        :raises TypeMismatchException: if a mapping key has the wrong type
        :raises TypeException: on invalid key labels, inaccessible values, invalid array
                               indices, or indexing into anything else
        """
        index = ast.key
        container_type = ast.arr.annotated_type
        # should have already been checked
        assert container_type.privacy_annotation.is_all_expr()
        container = container_type.type_name

        # do actual type checking
        if isinstance(container, Mapping):
            key_expected = AnnotatedTypeName(container.key_type, Expression.all_expr())
            if not index.instanceof(key_expected):
                raise TypeMismatchException(key_expected, index.annotated_type, ast)

            # record indexing information
            if container.key_label is not None:  # TODO modification correct?
                if not index.privacy_annotation_label():
                    raise TypeException(
                        f'Index cannot be used as a privacy type for array of type {container_type}',
                        ast)
                container.instantiated_key = index

            # determine value type
            ast.annotated_type = container.value_type
            if not self.is_accessible_by_invoker(ast):
                raise TypeException(
                    "Tried to read value which cannot be proven to be owned by the transaction invoker",
                    ast)
        elif isinstance(container, Array):
            if index.annotated_type.is_private():
                raise TypeException('No private array index', ast)
            if not index.instanceof_data_type(TypeName.number_type()):
                raise TypeException('Array index must be numeric', ast)
            ast.annotated_type = container.value_type
        else:
            raise TypeException('Indexing into non-mapping', ast)
示例#17
0
    def visitReclassifyExpr(self, ast: ReclassifyExpr):
        """
        Type-check a reveal / rehom expression and set its annotated type.

        The target homomorphism is ast.homomorphism if set, otherwise the operand's. A
        reclassification to @all always uses NON_HOMOMORPHIC (ast.homomorphism is updated).
        The operand must be public or provably @me, a RehomExpr with a NON_HOMOMORPHIC target
        cannot be applied to a public value, and redundant reclassifications (whose result is
        already an instance of the operand's type) are rejected.

        :raises TypeException: on an invalid target privacy, an inaccessible operand, unhom of
                               a public value, or a redundant reclassification
        """
        if not ast.privacy.privacy_annotation_label():
            raise TypeException(
                'Second argument of "reveal" cannot be used as a privacy type',
                ast)

        homomorphism = ast.homomorphism or ast.expr.annotated_type.homomorphism
        assert (homomorphism is not None)

        # Prevent ReclassifyExpr to all with homomorphic type
        if ast.privacy.is_all_expr() and homomorphism != Homomorphism.NON_HOMOMORPHIC:
            # If the target privacy is all, we infer a target homomorphism of NON_HOMOMORPHIC
            ast.homomorphism = homomorphism = Homomorphism.NON_HOMOMORPHIC

        # Make sure the first argument to reveal / rehom is public or private provably equal to @me
        is_expr_at_all = ast.expr.annotated_type.is_public()
        is_expr_at_me = ast.expr.annotated_type.is_private_at_me(ast.analysis)
        if not is_expr_at_all and not is_expr_at_me:
            # The two literals are concatenated; the trailing space keeps "accessible, i.e." readable
            raise TypeException(
                f'First argument of "{ast.func_name()}" must be accessible, '
                'i.e. @all or provably equal to @me', ast)

        # Prevent unhom(public_value)
        if is_expr_at_all and isinstance(ast, RehomExpr) \
                and ast.homomorphism == Homomorphism.NON_HOMOMORPHIC:
            raise TypeException(
                f'Cannot use "{ast.homomorphism.rehom_expr_name}" on a public value',
                ast)

        # NB prevent any redundant reveal (not just for public)
        ast.annotated_type = AnnotatedTypeName(
            ast.expr.annotated_type.type_name, ast.privacy, homomorphism)
        if ast.instanceof(ast.expr.annotated_type) is True:
            raise TypeException(
                f'Redundant "{ast.func_name()}": Expression is already '
                f'"@{ast.privacy.code()}{homomorphism}"', ast)
        self.check_for_invalid_private_type(ast)
示例#18
0
    def get_rhs(self, rhs: Expression, expected_type: AnnotatedTypeName):
        """
        Check that rhs is compatible with expected_type and return the expression to use instead.

        Tuples are checked element-wise and replaced by a new, typed TupleExpr. Otherwise rhs must
        be an instance of expected_type, or of expected_type with rhs's homomorphism (in which case
        a rehom is inserted via try_rehom). An implicit conversion is inserted if the type names
        differ, and a make_private reclassification if instanceof reports 'make-private'.

        :param rhs: the expression to check
        :param expected_type: the type rhs must conform to
        :return: rhs, or an expression wrapping it that has the expected type
        :raises TypeMismatchException: if rhs is not compatible with expected_type
        """
        if isinstance(rhs, TupleExpr):
            tuple_type = expected_type.type_name
            if not isinstance(tuple_type, TupleType) or len(rhs.elements) != len(tuple_type.types):
                raise TypeMismatchException(expected_type, rhs.annotated_type, rhs)
            exprs = [
                self.get_rhs(elem, elem_type)
                for elem_type, elem in zip(tuple_type.types, rhs.elements)
            ]
            return replace_expr(rhs, TupleExpr(exprs)).as_type(
                TupleType([e.annotated_type for e in exprs]))

        require_rehom = False
        instance = rhs.instanceof(expected_type)

        if not instance:
            # Retry ignoring the homomorphism; if that matches, a rehom is required
            require_rehom = True
            expected_matching_hom = expected_type.with_homomorphism(
                rhs.annotated_type.homomorphism)
            instance = rhs.instanceof(expected_matching_hom)

        if not instance:
            raise TypeMismatchException(expected_type, rhs.annotated_type, rhs)

        if rhs.annotated_type.type_name != expected_type.type_name:
            rhs = self.implicitly_converted_to(rhs, expected_type.type_name)

        if instance == 'make-private':
            return self.make_private(rhs, expected_type.privacy_annotation,
                                     expected_type.homomorphism)
        if require_rehom:
            return self.try_rehom(rhs, expected_type)
        return rhs
示例#19
0
    def make_private(expr: Expression, privacy: Expression):
        """
        Wrap expr in a ReclassifyExpr making it private to `privacy`.

        The new expression takes over expr's statement, parent and source location and
        becomes the new parent of expr.

        :param expr: expression to reclassify
        :param privacy: privacy annotation expression; must be usable as a privacy label
        :return: the new, typed ReclassifyExpr
        """
        label = privacy.privacy_annotation_label()
        assert label is not None

        owner = get_privacy_expr_from_label(label)
        wrapper = ReclassifyExpr(expr, owner)

        # set type
        wrapper.annotated_type = AnnotatedTypeName(expr.annotated_type.type_name, owner.clone())
        TypeCheckVisitor.check_for_invalid_private_type(wrapper)

        # take over expr's statement and source location
        wrapper.statement = expr.statement
        wrapper.line, wrapper.column = expr.line, expr.column

        # splice the wrapper into the tree between expr and its former parent
        wrapper.parent = expr.parent
        wrapper.annotated_type.parent = wrapper
        expr.parent = wrapper

        return wrapper
示例#20
0
 def visitAnnotatedTypeName(self,
                            ast: AnnotatedTypeName) -> Set[Homomorphism]:
     """Return {ast.homomorphism} if ast is private, otherwise an empty set."""
     if not ast.is_private():
         return set()
     return {ast.homomorphism}
示例#21
0
class GlobalDefs:
    """Built-in struct definitions for the types of `address`, `address payable`, `msg`, `block` and `tx`."""

    # gasleft: FunctionDefinition = FunctionDefinition(
    #     idf=Identifier('gasleft'),
    #     parameters=[],
    #     modifiers=[],
    #     return_parameters=[Parameter([], annotated_type=AnnotatedTypeName.uint_all(), idf=Identifier(''))],
    #     body=Block([])
    # )
    # gasleft.idf.parent = gasleft

    # Members of `address`: balance (uint)
    address_struct: StructDefinition = StructDefinition(
        Identifier('<address>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance'))
        ])
    # populate parent pointers inside the definition
    set_parents(address_struct)

    # Members of `address payable`: balance (uint), send(uint) returns (bool), transfer(uint)
    address_payable_struct: StructDefinition = StructDefinition(
        Identifier('<address_payable>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance')),
            ConstructorOrFunctionDefinition(
                Identifier('send'),
                [Parameter([], AnnotatedTypeName.uint_all(), Identifier(''))],
                ['public'],
                [Parameter([], AnnotatedTypeName.bool_all(), Identifier(''))],
                Block([])),
            ConstructorOrFunctionDefinition(
                Identifier('transfer'),
                [Parameter([], AnnotatedTypeName.uint_all(), Identifier(''))],
                ['public'], [], Block([])),
        ])
    # members[1] is send, members[2] is transfer: both are flagged can_be_private = False
    address_payable_struct.members[1].can_be_private = False
    address_payable_struct.members[2].can_be_private = False
    set_parents(address_payable_struct)

    # Members of `msg`: sender (address payable), value (uint)
    msg_struct: StructDefinition = StructDefinition(Identifier('<msg>'), [
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('sender')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('value')),
    ])
    set_parents(msg_struct)

    # Members of `block`: coinbase (address payable), difficulty, gaslimit, number, timestamp (uint)
    block_struct: StructDefinition = StructDefinition(Identifier('<block>'), [
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('coinbase')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('difficulty')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('gaslimit')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('number')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('timestamp')),
    ])
    set_parents(block_struct)

    # Members of `tx`: gasprice (uint), origin (address payable)
    tx_struct: StructDefinition = StructDefinition(Identifier('<tx>'), [
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('gasprice')),
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('origin')),
    ])
    set_parents(tx_struct)
示例#22
0
    def handle_function_body(self, ast: ConstructorOrFunctionDefinition):
        """
        Return offchain simulation python code for the body of function ast.

        In addition to what the original code does, the generated python code also:

        * checks that internal functions are not called externally
        * processes arguments (encryption, address wrapping for external calls),
        * introduces msg, block and tx objects as local variables (populated with current blockchain state)
        * serializes the public circuit outputs and the private circuit inputs, which are obtained during \
          simulation into int lists so that they can be passed to the proof generation
        * generates the NIZK proof (if needed)
        * calls/issues transaction with transformed arguments ((encrypted) original args, out array, proof)
          (or deploys the contract in case of the constructor)

        :param ast: the function or constructor to generate simulation code for
        :return: generated python source code (str), wrapped in a `with self._function_ctx(...)` block
        """
        # Preamble: external-call assertion (for external functions) and special variables as locals
        preamble_str = ''
        if ast.is_external:
            preamble_str += f'assert {IS_EXTERNAL_CALL}\n'
        preamble_str += f'msg, block, tx = {api("get_special_variables")}()\n' \
                        f'now = block.timestamp\n'
        circuit = self.current_circ

        # Initialize the private-values dict with default values for all secret circuit inputs
        if circuit and circuit.sec_idfs:
            priv_struct = StructDefinition(None, [
                VariableDeclaration([], AnnotatedTypeName(sec_idf.t), sec_idf)
                for sec_idf in circuit.sec_idfs
            ])
            preamble_str += f'\n{PRIV_VALUES_NAME}: Dict[str, Any] = {self.get_default_value(StructTypeName([], priv_struct))}\n'

        all_params = ', '.join(
            [f'{self.visit(param.idf)}' for param in self.current_params])
        if ast.can_be_external:
            # Wrap address strings in AddressValue object for external calls
            address_params = [
                self.visit(param.idf) for param in self.current_params
                if param.annotated_type.zkay_type.is_address()
            ]
            if address_params:
                assign_addr_str = f"{', '.join(address_params)} = {', '.join([f'AddressValue({p})' for p in address_params])}"
                preamble_str += f'\n{self.do_if_external(ast, [assign_addr_str])}\n'

        if ast.can_be_external and circuit:
            # Encrypt parameters and add private circuit inputs (plain + randomness)
            enc_param_str = ''
            for arg in self.current_params:
                if arg.annotated_type.is_cipher():
                    assert isinstance(arg.annotated_type.type_name, CipherText)
                    cipher: CipherText = arg.annotated_type.type_name
                    pname = self.visit(arg.idf)
                    plain_val = pname
                    plain_t = cipher.plain_type.type_name
                    crypto_params = cipher.crypto_params
                    crypto_str = f'crypto_backend="{crypto_params.crypto_name}"'
                    # Signed plaintexts are cast to the same-width uint if the backend requires it
                    if plain_t.is_signed_numeric and crypto_params.enc_signed_as_unsigned:
                        plain_val = self.handle_cast(
                            pname,
                            UintTypeName(f'uint{plain_t.elem_bitwidth}'))
                    enc_param_str += f'{self.get_priv_value(arg.idf.name)} = {plain_val}\n'
                    if crypto_params.is_symmetric_cipher():
                        # Symmetric: the last element of the enc result is replaced by my public key
                        my_pk = f'{api("get_my_pk")}("{crypto_params.crypto_name}")[0]'
                        enc_expr = f'{api("enc")}({self.get_priv_value(arg.idf.name)}, {crypto_str})'
                        enc_param_str += f'{pname} = CipherValue({enc_expr}[0][:-1] + ({my_pk}, ), {crypto_str})\n'
                    else:
                        # Asymmetric: also store the encryption randomness as private value <name>_R
                        enc_expr = f'{api("enc")}({self.get_priv_value(arg.idf.name)}, {crypto_str})'
                        enc_param_str += f'{pname}, {self.get_priv_value(f"{arg.idf.name}_R")} = {enc_expr}\n'

            enc_param_comment_str = '\n# Encrypt parameters' if enc_param_str else ''
            # drop the trailing newline
            enc_param_str = enc_param_str[:-1] if enc_param_str else ''

            actual_params_assign_str = f"actual_params = [{all_params}]"

            # Zero-initialized out array of the transformed size, appended as an extra argument
            out_var_decl_str = f'{cfg.zk_out_name}: List[int] = [0 for _ in range({circuit.out_size_trans})]'
            out_var_decl_str += f'\nactual_params.append({cfg.zk_out_name})'

            pre_body_code = self.do_if_external(ast, [
                enc_param_comment_str, enc_param_str, actual_params_assign_str,
                out_var_decl_str
            ])
        elif ast.can_be_external:
            pre_body_code = f'actual_params = [{all_params}]'
        else:
            pre_body_code = ''

        # Simulate public contract to compute in_values (state variable values are pulled from blockchain if necessary)
        # (out values are also computed when encountered, by locally evaluating and encrypting
        # the corresponding private expressions)
        body_str = self.visit(ast.body).strip()

        # Serialization code; non-primitive types get an element bit width of 0
        serialize_str = ''
        if circuit is not None:
            if circuit.output_idfs:
                out_elemwidths = ', '.join([
                    str(out.t.elem_bitwidth)
                    if out.t.is_primitive_type() else '0'
                    for out in circuit.output_idfs
                ])
                serialize_str += f'\n{cfg.zk_out_name}[{cfg.zk_out_name}_start_idx:{cfg.zk_out_name}_start_idx + {circuit.out_size}] = ' \
                                 f'{api("serialize_circuit_outputs")}(zk__data, [{out_elemwidths}])'
            if circuit.sec_idfs:
                sec_elemwidths = ', '.join([
                    str(sec.t.elem_bitwidth)
                    if sec.t.is_primitive_type() else '0'
                    for sec in circuit.sec_idfs
                ])
                serialize_str += f'\n{api("serialize_private_inputs")}({PRIV_VALUES_NAME}, [{sec_elemwidths}])'
        if serialize_str:
            serialize_str = f'\n# Serialize circuit outputs and/or secret circuit inputs\n' + serialize_str.lstrip(
            )

        body_code = '\n'.join(
            dedent(s) for s in [
                f'\n## BEGIN Simulate body',
                body_str,
                '## END Simulate body',
                serialize_str,
            ] if s) + '\n'

        # Add proof to actual argument list (when required)
        # (only for functions with side effects, i.e. those which are transacted below)
        generate_proof_str = ''
        fname = f"'{ast.name}'"
        if ast.can_be_external and circuit and ast.has_side_effects:
            generate_proof_str += '\n'.join([
                '\n#Generate proof',
                f"proof = {api('gen_proof')}({fname}, {cfg.zk_in_name}, {cfg.zk_out_name})",
                'actual_params.append(proof)'
            ])

        # Per-parameter flags (as python literals) telling whether the parameter is a ciphertext
        should_encrypt = ", ".join([
            str(p.annotated_type.is_cipher())
            for p in self.current_f.parameters
        ])
        # Constructor -> deploy, side effects -> transact, otherwise -> call (if it returns values)
        if ast.is_constructor:
            invoke_transact_str = f'''
            # Deploy contract
            {api("deploy")}(actual_params, [{should_encrypt}]{", wei_amount=wei_amount" if ast.is_payable else ""})
            '''
        elif ast.has_side_effects:
            invoke_transact_str = f'''
            # Invoke public transaction
            return {api("transact")}({fname}, actual_params, [{should_encrypt}]{", wei_amount=wei_amount" if ast.is_payable else ""})
            '''
        elif ast.return_parameters:
            # One (is_cipher, crypto backend name, type constructor) triple per return value
            constructors = []
            for retparam in ast.return_parameters:
                t = retparam.annotated_type.type_name
                if isinstance(t, CipherText):
                    constr = f'(True, "{t.crypto_params.crypto_name}", {self._get_type_constr(t.plain_type.type_name)})'
                else:
                    constr = f'(False, None, {self._get_type_constr(t)})'
                constructors.append(constr)
            constructors = f"[{', '.join(constructors)}]"

            invoke_transact_str = f'''
            # Call pure/view function and return value
            return {api('call')}({fname}, actual_params, {constructors})
            '''
        else:
            invoke_transact_str = ''

        # External path: proof + transaction; second list presumably the internal-call path
        # (return the locally computed return variables) -- see do_if_external
        post_body_code = self.do_if_external(ast, [
            generate_proof_str, invoke_transact_str
        ], [
            f'return {", ".join([f"{cfg.return_var_name}_{idx}" for idx in range(len(ast.return_parameters))])}'
            if ast.is_function and ast.requires_verification
            and ast.return_parameters else None
        ])

        # Functions which cannot be external assert that they are only called internally
        code = '\n\n'.join(s.strip() for s in [
            f'assert not {IS_EXTERNAL_CALL}'
            if not ast.can_be_external else None,
            dedent(preamble_str), pre_body_code, body_code, post_body_code
        ] if s)

        # Wrap everything in the function context (and a timing block for externally callable functions)
        func_ctx_params = []
        if circuit:
            func_ctx_params.append(str(circuit.priv_in_size_trans))
        if ast.is_payable:
            func_ctx_params.append('wei_amount=wei_amount')
        if ast.can_be_external:
            func_ctx_params.append(f'name={fname}')
            code = 'with time_measure("transaction_full", skip=not zk__is_ext):\n' + indent(
                code)
        code = f'with self._function_ctx({", ".join(func_ctx_params)}) as {IS_EXTERNAL_CALL}:\n' + indent(
            code)
        return code
示例#23
0
# BUILTIN SPECIAL TYPE DEFINITIONS
from zkay.zkay_ast.ast import AnnotatedTypeName, FunctionTypeName, Parameter, Identifier, StructDefinition, \
    VariableDeclaration, TypeName, StateVariableDeclaration, UserDefinedTypeName, StructTypeName, Block, ConstructorOrFunctionDefinition
from zkay.zkay_ast.pointers.parent_setter import set_parents

# Shared declaration of a built-in member named `length` with type AnnotatedTypeName.uint_all().
# NOTE(review): presumably used to resolve `<array>.length` member accesses — confirm against its users.
array_length_member = VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                          Identifier('length'))


class GlobalDefs:
    # gasleft: FunctionDefinition = FunctionDefinition(
    #     idf=Identifier('gasleft'),
    #     parameters=[],
    #     modifiers=[],
    #     return_parameters=[Parameter([], annotated_type=AnnotatedTypeName.uint_all(), idf=Identifier(''))],
    #     body=Block([])
    # )
    # gasleft.idf.parent = gasleft

    address_struct: StructDefinition = StructDefinition(
        Identifier('<address>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance'))
        ])
    set_parents(address_struct)

    address_payable_struct: StructDefinition = StructDefinition(
        Identifier('<address_payable>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance')),
            ConstructorOrFunctionDefinition(
示例#24
0
    def evaluate_stmt_in_circuit(self, ast: Statement) -> AssignmentStatement:
        """
        Evaluate a whole statement privately inside the circuit.

        The statement is wrapped into an imaginary private function whose body is the statement followed by a
        return statement yielding the most recent SSA version of every external location (declared outside the
        statement) that the statement modifies. A call to this function is inlined into the circuit, and the
        result is turned into an assignment statement with

        * lhs: tuple of the modified external locations
        * rhs: tuple of the corresponding circuit outputs (owned by @me)

        If the statement modifies no external location, a dummy ``ExpressionStatement(0)`` is returned instead.

        Note: Modifying external locations which are not owned by @me inside the statement is illegal (would leak information).
        Note: At the moment, this is only used for if statements with a private condition.

        :param ast: the statement to evaluate inside the circuit
        :return: the AssignmentStatement described above (or the dummy expression statement)
        """
        # Locations declared outside of the statement which the statement writes to
        outer_vars = [var for var in ast.modified_values if var.in_scope_at(ast)]
        if outer_vars:
            astmt = AssignmentStatement(None, None)
        else:
            astmt = ExpressionStatement(NumberLiteralExpr(0))
        astmt.before_analysis = ast.before_analysis

        # Each outer location becomes a return value of the imaginary function
        ret_params = []
        for var in outer_vars:
            # side effect affects location outside statement and has privacy @me
            assert ast.before_analysis.same_partition(var.privacy, Expression.me_expr())
            target = var.target
            assert isinstance(target, (Parameter, VariableDeclaration, StateVariableDeclaration))
            zkay_t = target.annotated_type.zkay_type
            if not zkay_t.type_name.is_primitive_type():
                raise NotImplementedError('Reference types inside private if statements are not supported')
            me_t = AnnotatedTypeName(zkay_t.type_name, Expression.me_expr(), zkay_t.homomorphism)  # same type, but @me
            ret_param = IdentifierExpr(target.idf.clone(), me_t).override(target=target)
            ret_param.statement = astmt
            ret_params.append(ret_param)

        # The imaginary function: body = statement + return of all modified outer locations
        stmt_fct = ConstructorOrFunctionDefinition(
            Identifier('<stmt_fct>'), [], ['private'],
            [Parameter([], rp.annotated_type, rp.target.idf) for rp in ret_params],
            Block([ast, ReturnStatement(TupleExpr(ret_params))]))
        stmt_fct.original_body = stmt_fct.body
        stmt_fct.body.parent = stmt_fct
        stmt_fct.parent = ast

        # Inline a call to the imaginary function into the circuit
        call = FunctionCallExpr(IdentifierExpr('<stmt_fct>').override(target=stmt_fct), [])
        call.statement = astmt
        ret_args = self.inline_function_call_into_circuit(call)

        # Move every returned value out of the circuit
        if not isinstance(ret_args, TupleExpr):
            ret_args = TupleExpr([ret_args])
        for elem in ret_args.elements:
            elem.statement = astmt
        ret_arg_outs = [
            self._get_circuit_output_for_private_expression(elem, Expression.me_expr(), rp.annotated_type.homomorphism)
            for rp, elem in zip(ret_params, ret_args.elements)
        ]

        if not ret_params:
            # Nothing external was modified -> only the dummy statement
            assert isinstance(astmt, ExpressionStatement)
            return astmt

        astmt.lhs = TupleExpr([rp.clone() for rp in ret_params])
        astmt.rhs = TupleExpr(ret_arg_outs)
        return astmt
示例#25
0
    def _get_circuit_output_for_private_expression(
            self, expr: Expression, new_privacy: PrivacyLabelExpr,
            homomorphism: Homomorphism) -> LocationExpr:
        """
        Add evaluation of expr to the circuit and return a location expression for the circuit output holding the result.

        If the result is public (new_privacy is an AllExpr) or expr's type is already a ciphertext, a plain circuit output
        is added together with an equality constraint (appended to self._phi). Otherwise an encrypted circuit output is
        added, and _ensure_encryption is asked to constrain it to the encryption of the result under the (canonicalized)
        owner new_privacy.

        Note: has side effects on expr.statement (adds pre_statement)

        :param expr: [SIDE EFFECT] expression to evaluate (if it is a homomorphic builtin function call, its
                     annotated_type is replaced by a ciphertext type)
        :param new_privacy: result owner (determines encryption key)
        :param homomorphism: homomorphism of the result; selects the ciphertext type and the crypto parameters
                             (cfg.get_crypto_params) used for encrypted outputs and for homomorphic builtin calls
        :return: location expression referencing the new circuit output containing the result of expr
                 (in the plain case explicitly converted to expr's type name)
        """
        # Values which already live in the circuit (except public contract values)
        is_circ_val = isinstance(expr, IdentifierExpr) and isinstance(
            expr.idf, HybridArgumentIdf
        ) and expr.idf.arg_type != HybridArgType.PUB_CONTRACT_VAL
        # Calls to homomorphic builtin functions
        is_hom_comp = isinstance(expr, FunctionCallExpr) and isinstance(
            expr.func, BuiltinFunction
        ) and expr.func.homomorphism != Homomorphism.NON_HOMOMORPHIC
        if is_hom_comp:
            # Treat a homomorphic operation as a privately evaluated operation on (public) ciphertexts
            expr.annotated_type = AnnotatedTypeName.cipher_type(
                expr.annotated_type, homomorphism)

        if is_circ_val or expr.annotated_type.is_private(
        ) or expr.evaluate_privately:
            priv_result_idf = self._evaluate_private_expression(expr)
        else:
            # For public expressions which should not be evaluated in private, only the result is moved into the circuit
            priv_result_idf = self.add_to_circuit_inputs(expr)
        private_expr = priv_result_idf.get_idf_expr()

        # Append the source variable name to the output name for plain identifiers (readability of generated code)
        t_suffix = ''
        if isinstance(expr, IdentifierExpr) and not is_circ_val:
            t_suffix += f'_{expr.idf.name}'

        if isinstance(new_privacy,
                      AllExpr) or expr.annotated_type.type_name.is_cipher():
            # If the result is public, add an equality constraint to ensure that the user supplied public output
            # is equal to the circuit evaluation result
            tname = f'{self._out_name_factory.get_new_name(expr.annotated_type.type_name)}{t_suffix}'
            new_out_param = self._out_name_factory.add_idf(
                tname, expr.annotated_type.type_name, private_expr)
            self._phi.append(CircEqConstraint(priv_result_idf, new_out_param))
            out_var = new_out_param.get_loc_expr().explicitly_converted(
                expr.annotated_type.type_name)
        else:
            # If the result is encrypted, add an encryption constraint to ensure that the user supplied encrypted output
            # is equal to the correctly encrypted circuit evaluation result
            new_privacy = self._get_canonical_privacy_label(
                expr.analysis, new_privacy)
            privacy_label_expr = get_privacy_expr_from_label(new_privacy)
            cipher_t = TypeName.cipher_type(expr.annotated_type, homomorphism)
            tname = f'{self._out_name_factory.get_new_name(cipher_t)}{t_suffix}'
            enc_expr = EncryptionExpression(private_expr, privacy_label_expr,
                                            homomorphism)
            new_out_param = self._out_name_factory.add_idf(
                tname, cipher_t, enc_expr)
            crypto_params = cfg.get_crypto_params(homomorphism)
            self._ensure_encryption(expr.statement, priv_result_idf,
                                    new_privacy, crypto_params, new_out_param,
                                    False, False)
            out_var = new_out_param.get_loc_expr()

        # Add an invisible CircuitComputationStatement to the solidity code, which signals the offchain simulator,
        # that the value the contained out variable must be computed at this point by simulating expression evaluation
        expr.statement.pre_statements.append(
            CircuitComputationStatement(new_out_param))
        return out_var
示例#26
0
 def visitAnnotatedTypeName(self, ast: AnnotatedTypeName):
     """
     Return a new AnnotatedTypeName with a transformed type name.

     Private types are replaced by the corresponding cipher type (TypeName.cipher_type); for public types,
     a clone of the type name is visited.
     """
     if not ast.is_private():
         return AnnotatedTypeName(self.visit(ast.type_name.clone()))
     return AnnotatedTypeName(TypeName.cipher_type(ast))
示例#27
0
 @staticmethod
 def visitMeExpr(ast: MeExpr):
     """
     Replace me with msg.sender.

     Declared static because the signature has no ``self``: without the decorator, an instance-bound call
     such as ``self.visitMeExpr(ast)`` would pass two positional arguments and raise TypeError.

     :param ast: the ``me`` expression to replace
     :return: the ``msg.sender`` expression that replaces ast (via replace_expr), typed as
              AnnotatedTypeName.address_all()
     """
     return replace_expr(ast, IdentifierExpr('msg').dot('sender')).as_type(AnnotatedTypeName.address_all())
示例#28
0
    @staticmethod
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition, ext_circuit: CircuitHelper,
                                     original_params: List[Parameter], requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        The generated body, in order:
        * requests msg.sender's private key for the symmetric ciphers that need it;
        * checks the size of the out array;
        * declares and fills the in array (static public keys first, then the encrypted arguments);
        * copies symmetrically encrypted arguments to memory with their sender field set;
        * calls the internal function;
        * optionally calls the verifier;
        * returns the return values.

        Declared static because the signature has no ``self``: without the decorator, an instance-bound call
        would pass an extra positional argument.

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: whether to emit the zk proof verification call (also skipped when
                               cfg.disable_verification is set)
        :return: body with wrapper code
        """
        priv_args = [p for p in original_params if p.annotated_type.is_cipher()]
        # Ordered set of the crypto backends used by the private arguments (first-occurrence order)
        args_backends = OrderedDict.fromkeys([p.annotated_type.type_name.crypto_params for p in priv_args])
        stmts = []

        for crypto_params in args_backends:
            assert crypto_params in int_fct.used_crypto_backends
            # If there are any private arguments with homomorphism 'hom', we need the public key for that crypto backend
            ext_circuit._require_public_key_for_label_at(None, Expression.me_expr(), crypto_params)
        for crypto_params in cfg.all_crypto_params():
            if crypto_params.is_symmetric_cipher():
                if (MeExpr(), crypto_params) in ext_circuit.requested_global_keys or crypto_params in args_backends:
                    # Make sure msg.sender's key pair is available in the circuit
                    stmts += ext_circuit.request_private_key(crypto_params)

        # Verify that out parameter has correct size
        stmts += [RequireStatement(IdentifierExpr(cfg.zk_out_name).dot('length').binop('==', NumberLiteralExpr(ext_circuit.out_size_trans)))]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))

        # Request static public keys
        offset = 0
        key_req_stmts = []
        me_key_idx: Dict[CryptoParams, int] = {}
        if ext_circuit.requested_global_keys:
            # Ensure that me public key is stored starting at in[0]
            # (iterate over a copy of the requested keys)
            keys = list(ext_circuit.requested_global_keys)

            # One temporary key variable per crypto backend used by the internal function
            tmp_keys = {}
            for crypto_params in int_fct.used_crypto_backends:
                tmp_key_var = Identifier(f'_tmp_key_{crypto_params.identifier_name}')
                key_req_stmts.append(tmp_key_var.decl_var(AnnotatedTypeName.key_type(crypto_params)))
                tmp_keys[crypto_params] = tmp_key_var
            for (key_owner, crypto_params) in keys:
                tmp_key_var = tmp_keys[crypto_params]
                _, assignment = ext_circuit.request_public_key(crypto_params, key_owner, ext_circuit.get_glob_key_name(key_owner, crypto_params))
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Remember me-keys for later use in symmetrically encrypted keys
                if key_owner == MeExpr():
                    assert crypto_params not in me_key_idx
                    me_key_idx[crypto_params] = offset

                # Manually add to circuit inputs
                key_len = crypto_params.key_len
                key_req_stmts.append(in_arr_var.slice(offset, key_len).assign(IdentifierExpr(tmp_key_var.clone()).slice(0, key_len)))
                offset += key_len
                # The in array must be filled in lockstep with the circuit's input size
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            # * of T_e rule 8
            if p.annotated_type.is_cipher():
                cipher_payload_len = p.annotated_type.type_name.crypto_params.cipher_payload_len
                assign_stmt = in_arr_var.slice(offset, cipher_payload_len).assign(IdentifierExpr(p.idf.clone()).slice(0, cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cipher_payload_len

        # Populate sender field of parameters encrypted with a symmetric cipher
        copy_stmts = []
        for p in original_params:
            if p.annotated_type.is_cipher():
                c = p.annotated_type.type_name
                assert isinstance(c, CipherText)
                if c.crypto_params.is_symmetric_cipher():
                    sender_key = in_arr_var.index(me_key_idx[c.crypto_params])
                    idf = IdentifierExpr(p.idf.clone()).as_type(p.annotated_type.clone())
                    # Use the ciphertext's own crypto params, consistent with the payload slice above
                    cipher_payload_len = c.crypto_params.cipher_payload_len
                    lit = ArrayLiteralExpr([idf.clone().index(i) for i in range(cipher_payload_len)] + [sender_key])
                    copy_stmts.append(VariableDeclarationStatement(VariableDeclaration([], p.annotated_type.clone(), p.idf.clone(), 'memory'), lit))
        if copy_stmts:
            param_stmts += [Comment(), Comment('Copy from calldata to memory and set sender field')] + copy_stmts

        # Declare in array
        new_in_array_expr = NewExpr(AnnotatedTypeName(TypeName.dyn_uint_array()), [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(), new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        stmts += Comment.comment_wrap_block('Request static public keys', key_req_stmts)
        stmts += Comment.comment_wrap_block('Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # NOTE(review): extends args in place; presumably FunctionCallExpr keeps a reference to this list,
            # so this appends the verification arguments to internal_call — confirm
            args += [in_arr_var.clone(), NumberLiteralExpr(ext_circuit.in_size),
                     IdentifierExpr(cfg.zk_out_name), NumberLiteralExpr(ext_circuit.out_size)]

        if int_fct.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [VariableDeclarationStatement(deep_copy(vd)) for vd in int_fct.return_var_decls])
            in_call = TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof and not cfg.disable_verification:
            verifier = IdentifierExpr(cfg.get_contract_var_name(ext_circuit.verifier_contract_type.code()))
            verifier_args = [IdentifierExpr(cfg.proof_param_name), IdentifierExpr(cfg.zk_in_name), IdentifierExpr(cfg.zk_out_name)]
            verify = ExpressionStatement(verifier.call(cfg.verification_function_name, verifier_args))
            stmts.append(StatementList([Comment('Verify zk proof of execution'), verify], excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(ReturnStatement(TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls])))

        return Block(stmts)
示例#29
0
    def transform_contract(self, su: SourceUnit, c: ContractDefinition) -> ContractDefinition:
        """
        Translate a zkay contract into a public solidity contract, modifying it in place.

        Steps:

        * transform the types of state variables, the function signatures and the function bodies
        * include the verification contracts
        * add a zk_data struct for every function with verification whose circuit needs one \
          (stores circuit I/O, bypasses the solidity stack limit and allows easy assignment of array variables)
        * create external wrapper functions for all public functions which require verification when called externally
        * add circuit IO serialization/deserialization code to all functions which require verification

        :param su: [SIDE EFFECTS] source unit which contains this contract
        :param c: [SIDE EFFECTS] the contract to transform
        :return: the contract c itself
        """
        all_fcts = c.constructor_definitions + c.function_definitions

        # Static owner labels: me, plus every final/constant address state variable
        global_owners = [Expression.me_expr()] + [
            sv.idf for sv in c.state_variable_declarations
            if sv.annotated_type.is_address() and (sv.is_final or sv.is_constant)
        ]

        # Keep copies of the untransformed bodies
        for func in all_fcts:
            func.original_body = deep_copy(func.body, with_types=True, with_analysis=True)

        # Transform the types of the user state variables
        c.state_variable_declarations = self.var_decl_trafo.visit_list(c.state_variable_declarations)

        # Create circuit helpers where needed; functions needing an external wrapper are handled at the end,
        # all others are kept as they are
        req_ext_fcts = {}
        new_fcts = []
        new_constr = []
        for func in all_fcts:
            assert isinstance(func, ConstructorOrFunctionDefinition)
            if func.requires_verification or func.requires_verification_when_external:
                self.circuits[func] = self.create_circuit_helper(func, global_owners)

            if func.requires_verification_when_external:
                req_ext_fcts[func] = func.parameters[:]
            else:
                (new_constr if func.is_constructor else new_fcts).append(func)

        # Prepend the field prime constant and the state variables of the helper contracts
        field_prime_decl = StateVariableDeclaration(AnnotatedTypeName.uint_all(), ['public', 'constant'],
                                                    Identifier(cfg.field_prime_var_name),
                                                    NumberLiteralExpr(bn128_scalar_field))
        contract_var_decls = self.include_verification_contracts(su, c)
        c.state_variable_declarations = ([field_prime_decl, Comment()]
                                         + Comment.comment_list('Helper Contracts', contract_var_decls)
                                         + [Comment('User state variables')]
                                         + c.state_variable_declarations)

        # Signatures
        for func in all_fcts:
            func.parameters = self.var_decl_trafo.visit_list(func.parameters)
        for func in c.function_definitions:
            func.return_parameters = self.var_decl_trafo.visit_list(func.return_parameters)
            func.return_var_decls = self.var_decl_trafo.visit_list(func.return_var_decls)

        # Bodies
        for func in all_fcts:
            func.body = ZkayStatementTransformer(self.circuits.get(func)).visit(func.body)

        # Internal functions with verification: extra parameters and boilerplate code
        fcts_with_verification = [func for func in all_fcts if func.requires_verification]
        compute_transitive_circuit_io_sizes(fcts_with_verification, self.circuits)
        transform_internal_calls(fcts_with_verification, self.circuits)
        for func in fcts_with_verification:
            circuit = self.circuits[func]
            assert circuit.requires_verification()
            if circuit.requires_zk_data_struct():
                # zk data struct for func
                zk_data_struct = StructDefinition(Identifier(circuit.zk_data_struct_name), [
                    VariableDeclaration([], AnnotatedTypeName(idf.t), idf.clone(), '')
                    for idf in circuit.output_idfs + circuit.input_idfs
                ])
                c.struct_definitions.append(zk_data_struct)
            self.create_internal_verification_wrapper(func)

        # External wrappers (the internal part is always a plain function)
        for func, params in req_ext_fcts.items():
            ext_f, int_f = self.split_into_external_and_internal_fct(func, params, global_owners)
            (new_fcts if ext_f.is_function else new_constr).append(ext_f)
            new_fcts.append(int_f)

        c.constructor_definitions = new_constr
        c.function_definitions = new_fcts
        return c
示例#30
0
    def create_internal_verification_wrapper(self, ast: ConstructorOrFunctionDefinition):
        """
        Add the necessary additional parameters and boiler plate code for verification support to the given function.

        Four parameters are appended: the zk in array, its start index, the zk out array and its start index.
        The body is rewritten to, in order:
        * check the in/out array bounds;
        * declare the zk_data struct and the return variables (if needed);
        * deserialize the circuit outputs;
        * run the original transformed body;
        * serialize the circuit inputs;
        * return the return variables (if the circuit has a return variable).

        :param ast: [SIDE EFFECT] Internal function which requires verification
        """
        circuit = self.circuits[ast]
        stmts = []

        symmetric_cipher_used = any(backend.is_symmetric_cipher() for backend in ast.used_crypto_backends)
        if symmetric_cipher_used and 'pure' in ast.modifiers:
            # Symmetric trafo requires msg.sender access -> change from pure to view
            ast.modifiers = ['view' if mod == 'pure' else mod for mod in ast.modifiers]

        # Add additional params
        ast.add_param(Array(AnnotatedTypeName.uint_all()), cfg.zk_in_name)
        ast.add_param(AnnotatedTypeName.uint_all(), f'{cfg.zk_in_name}_start_idx')
        ast.add_param(Array(AnnotatedTypeName.uint_all()), cfg.zk_out_name)
        ast.add_param(AnnotatedTypeName.uint_all(), f'{cfg.zk_out_name}_start_idx')

        # Verify that in/out parameters have correct size
        out_start_idx, in_start_idx = IdentifierExpr(f'{cfg.zk_out_name}_start_idx'), IdentifierExpr(f'{cfg.zk_in_name}_start_idx')
        out_var, in_var = IdentifierExpr(cfg.zk_out_name), IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))
        stmts.append(RequireStatement(out_start_idx.binop('+', NumberLiteralExpr(circuit.out_size_trans)).binop('<=', out_var.dot('length'))))
        stmts.append(RequireStatement(in_start_idx.binop('+', NumberLiteralExpr(circuit.in_size_trans)).binop('<=', in_var.dot('length'))))

        # Declare zk_data struct var (if needed)
        if circuit.requires_zk_data_struct():
            zk_struct_type = StructTypeName([Identifier(circuit.zk_data_struct_name)])
            stmts += [Identifier(cfg.zk_data_var_name).decl_var(zk_struct_type), BlankLine()]

        # Declare return variable if necessary
        if ast.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [VariableDeclarationStatement(vd) for vd in ast.return_var_decls])

        # Find all me-keys in the in array
        # (global keys occupy the in array consecutively, in the order of requested_global_keys)
        me_key_idx: Dict[CryptoParams, int] = {}
        offset = 0
        for (key_owner, crypto_params) in circuit.requested_global_keys:
            if key_owner == MeExpr():
                assert crypto_params not in me_key_idx
                me_key_idx[crypto_params] = offset
            offset += crypto_params.key_len

        # Deserialize out array (if any)
        deserialize_stmts = []
        offset = 0
        for s in circuit.output_idfs:
            deserialize_stmts.append(s.deserialize(cfg.zk_out_name, out_start_idx, offset))
            if isinstance(s.t, CipherText) and s.t.crypto_params.is_symmetric_cipher():
                # Assign sender field to user-encrypted values if necessary
                # Assumption: s.t.crypto_params.key_len == 1 for all symmetric ciphers
                assert s.t.crypto_params in me_key_idx, "Symmetric cipher but did not request me key"
                key_idx = me_key_idx[s.t.crypto_params]
                sender_key = in_var.index(key_idx)
                cipher_payload_len = s.t.crypto_params.cipher_payload_len
                # The sender field is the element right after the cipher payload
                deserialize_stmts.append(s.get_loc_expr().index(cipher_payload_len).assign(sender_key))
            offset += s.t.size_in_uints
        if deserialize_stmts:
            stmts.append(StatementList(Comment.comment_wrap_block("Deserialize output values", deserialize_stmts), excluded_from_simulation=True))

        # Include original transformed function body
        stmts += ast.body.statements

        # Serialize in parameters to in array (if any)
        serialize_stmts = []
        offset = 0
        for s in circuit.input_idfs:
            serialize_stmts += [s.serialize(cfg.zk_in_name, in_start_idx, offset)]
            offset += s.t.size_in_uints
        if offset:
            stmts.append(Comment())
            stmts += Comment.comment_wrap_block('Serialize input values', serialize_stmts)

        # Add return statement at the end if necessary
        # (was previously replaced by assignment to return_var by ZkayStatementTransformer)
        if circuit.has_return_var:
            stmts.append(ReturnStatement(TupleExpr([IdentifierExpr(vd.idf.clone()).override(target=vd) for vd in ast.return_var_decls])))

        # Replace the body's statements in place (keeps the same list object)
        ast.body.statements[:] = stmts