示例#1
0
    def visitFunctionCallExpr(self, ast: FunctionCallExpr):
        """
        Type-check a function call expression and set ``ast.annotated_type``.

        Three kinds of callee are supported:

        * a ``BuiltinFunction`` (operator) -> delegated to
          ``handle_builtin_function_call``;
        * a cast (``ast.is_cast``) -> only casts to enum types are implemented;
        * a ``LocationExpr`` referring to a function -> every argument is checked
          against the corresponding parameter type (and replaced by the result of
          ``get_rhs``), and the call's type is the single return type or a tuple
          of all return types.

        :param ast: the call expression; annotated (and its args replaced) in place
        :raise NotImplementedError: for casts to a non-enum user type
        :raise TypeException: on an argument-count mismatch or an unsupported callee
        """
        if isinstance(ast.func, BuiltinFunction):
            self.handle_builtin_function_call(ast, ast.func)
        elif ast.is_cast:
            if not isinstance(ast.func.target, EnumDefinition):
                raise NotImplementedError(
                    'User type casts only implemented for enums')
            ast.annotated_type = self.handle_cast(
                ast.args[0], ast.func.target.annotated_type.type_name)
        elif isinstance(ast.func, LocationExpr):
            ft = ast.func.annotated_type.type_name

            if len(ft.parameters) != len(ast.args):
                raise TypeException("Wrong number of arguments", ast.func)

            # Check arguments. Lengths are equal (checked above), so pairing by
            # index is safe; each argument is replaced one at a time, in order.
            for i, param in enumerate(ft.parameters):
                ast.args[i] = self.get_rhs(ast.args[i], param.annotated_type)

            # Set expression type to return type
            if len(ft.return_parameters) == 1:
                ast.annotated_type = ft.return_parameters[0].annotated_type.clone()
            else:
                # Zero or several return values -> tuple type.
                # TODO maybe not None label in the future
                ast.annotated_type = AnnotatedTypeName(
                    TupleType([t.annotated_type for t in ft.return_parameters]),
                    None)
        else:
            raise TypeException('Invalid function call', ast)
示例#2
0
    def handle_homomorphic_builtin_function_call(self, ast: FunctionCallExpr,
                                                 func: BuiltinFunction):
        """
        Type-check a builtin operator call through a homomorphic overload.

        The arguments are first checked against the operator's plain signature
        (skipped when ``func.is_eq()``). A homomorphic overload is then selected
        for the arguments. The operation is marked private, ``ast`` receives the
        overload's output type, ``func`` receives that type's homomorphism, and
        every argument is converted (via ``get_rhs``) to the overload's input type.

        :raise TypeMismatchException: if an argument does not match the signature
        :raise TypeException: if no homomorphic overload is applicable
        """
        # Same signature check as the non-homomorphic path
        if not func.is_eq():
            for argument, expected in zip(ast.args, func.input_types()):
                if argument.instanceof_data_type(expected):
                    continue
                raise TypeMismatchException(expected,
                                            argument.annotated_type.type_name,
                                            argument)

        hom_overload = func.select_homomorphic_overload(ast.args, ast.analysis)
        if hom_overload is None:
            raise TypeException(
                f'Operation \'{func.op}\' requires all arguments to be accessible, '
                f'i.e. @all or provably equal to @me', ast)

        # Rather than evaluating homomorphic operations on-chain (which would need
        # an arbitrary precision math library in Solidity), they are evaluated in
        # Python and the result is checked inside the circuit.
        func.is_private = True

        ast.annotated_type = hom_overload.output_type()
        func.homomorphism = ast.annotated_type.homomorphism
        param_types = hom_overload.input_types()

        # Convert every argument to the type the selected overload expects
        converted_args = [self.get_rhs(argument, param_type)
                          for argument, param_type in zip(ast.args, param_types)]
        ast.args[:] = converted_args
示例#3
0
文件: build_ast.py 项目: nibau/zkay
 def visitAssignmentExpr(self, ctx: SolidityParser.AssignmentExprContext):
     """
     Build an AssignmentStatement from an assignment expression.

     A compound assignment ``lhs op= rhs`` is desugared into ``lhs = lhs op rhs``.
     The statement's op is the operator text without its trailing '='
     ('' for a plain assignment).

     :raise SyntaxException: if the assignment is not used as a statement
     """
     if not self.is_expr_stmt(ctx):
         raise SyntaxException('Assignments are only allowed as statements', ctx, self.code)
     target = self.visit(ctx.lhs)
     value = self.visit(ctx.rhs)
     op_tok = ctx.op
     assert op_tok.text[-1] == '='
     # Strip the trailing '='; a plain '=' yields ''
     op = op_tok.text[:-1]
     if not op:
         return ast.AssignmentStatement(target, value, op)

     # Compound assignment: ctx.lhs is visited again (instead of reusing
     # `target`) to form the left operand of the binary operation.
     builtin = BuiltinFunction(op).override(line=op_tok.line, column=op_tok.column)
     value = FunctionCallExpr(builtin, [self.visit(ctx.lhs), value])
     value.line = ctx.rhs.start.line
     value.column = ctx.rhs.start.column + 1
     return ast.AssignmentStatement(target, value, op)
示例#4
0
文件: build_ast.py 项目: nibau/zkay
    def _handle_crement_expr(self, ctx, kind: str):
        """
        Desugar an increment/decrement into an assignment ``x = x + 1`` / ``x = x - 1``.

        :param ctx: parser context providing ``op`` ('++' or another crement token) and ``expr``
        :param kind: prefix used in the error message and prepended to the op text
                     in the resulting statement (presumably 'pre'/'post' — confirm with callers)
        :raise SyntaxException: if the expression is not used as a statement
        """
        if not self.is_expr_stmt(ctx):
            raise SyntaxException(f'{kind}-crement expressions are only allowed as statements', ctx, self.code)
        op_tok = ctx.op
        line, col = op_tok.line, op_tok.column
        arith_op = '-' if op_tok.text != '++' else '+'

        one = NumberLiteralExpr(1)
        one.line = line
        one.column = col + 1

        builtin = BuiltinFunction(arith_op).override(line=line, column=col)
        update = FunctionCallExpr(builtin, [self.visit(ctx.expr), one])
        update.line = line
        update.column = col + 1

        # The target is visited a second time for the assignment's lhs
        return ast.AssignmentStatement(self.visit(ctx.expr), update, f'{kind}{op_tok.text}')
示例#5
0
文件: build_ast.py 项目: nibau/zkay
    def visitFunctionCallExpr(self, ctx: SolidityParser.FunctionCallExprContext):
        """
        Build a call expression; a call to ``reveal`` becomes a ReclassifyExpr.

        :raise SyntaxException: if ``reveal`` is not given exactly two arguments
        """
        func = self.visit(ctx.func)
        args = self.handle_field(ctx.args)

        is_reveal = isinstance(func, IdentifierExpr) and func.idf.name == 'reveal'
        if not is_reveal:
            return FunctionCallExpr(func, args)
        if len(args) != 2:
            raise SyntaxException(f'Invalid number of arguments for reveal: {args}', ctx.args, self.code)
        return ReclassifyExpr(args[0], args[1])
示例#6
0
    def handle_builtin_function_call(self, ast: FunctionCallExpr,
                                     func: BuiltinFunction):
        """
        Dispatch type checking of a builtin operator call.

        Parentheses forward their argument's type. Otherwise the non-homomorphic
        handler is used when every argument is accessible (``is_accessible``) or
        the call is an ite with a public condition; all other calls go to the
        homomorphic handler.
        """
        if func.is_parenthesis():
            ast.annotated_type = ast.args[0].annotated_type
            return

        every_arg_accessible = all(
            arg.annotated_type.is_accessible(ast.analysis) for arg in ast.args)
        public_condition_ite = (func.is_ite()
                                and ast.args[0].annotated_type.is_public())

        if not (every_arg_accessible or public_condition_ite):
            self.handle_homomorphic_builtin_function_call(ast, func)
        else:
            self.handle_unhom_builtin_function_call(ast, func)
示例#7
0
文件: build_ast.py 项目: eth-sri/zkay
    def visitFunctionCallExpr(self, ctx: SolidityParser.FunctionCallExprContext):
        """
        Build a call expression, special-casing two pseudo-functions.

        * ``reveal(a, b)`` becomes ``ReclassifyExpr(a, b, None)``.
        * A call to a name registered in ``self.rehom_expressions`` becomes a
          ``RehomExpr`` on its single argument, with the registered homomorphism.

        :raise SyntaxException: if either special form has the wrong argument count
        """
        func = self.visit(ctx.func)
        args = self.handle_field(ctx.args)

        if not isinstance(func, IdentifierExpr):
            return FunctionCallExpr(func, args)

        name = func.idf.name
        if name == 'reveal':
            if len(args) != 2:
                raise SyntaxException(f'Invalid number of arguments for reveal: {args}', ctx.args, self.code)
            return ReclassifyExpr(args[0], args[1], None)

        if name in self.rehom_expressions:
            target_hom = self.rehom_expressions[name]
            if len(args) != 1:
                raise SyntaxException(f'Invalid number of arguments for {name}: {args}', ctx.args, self.code)
            return RehomExpr(args[0], target_hom)

        return FunctionCallExpr(func, args)
示例#8
0
    def visitFunctionCallExpr(self, ast: FunctionCallExpr):
        """
        Propagate the state variables read and modified by the callee to the call.

        ``self.fixed_point_reached`` is and-ed with "nothing new was added", so it
        becomes False as soon as either set grows.
        """
        self.visitAST(ast)
        if not isinstance(ast.func, LocationExpr):
            return

        # for now no reference types -> only state could have been modified
        callee = ast.func.target

        reads_before = len(ast.read_values)
        ast.read_values.update({
            val for val in callee.read_values
            if isinstance(val.target, StateVariableDeclaration)
        })
        self.fixed_point_reached &= reads_before == len(ast.read_values)

        # update modified values if any (modified_values is used as an ordered set)
        writes_before = len(ast.modified_values)
        for val in callee.modified_values:
            if isinstance(val.target, StateVariableDeclaration):
                ast.modified_values[val] = None
        self.fixed_point_reached &= writes_before == len(ast.modified_values)
示例#9
0
    def handle_builtin_function_call(self, ast: FunctionCallExpr,
                                     func: BuiltinFunction):
        """
        Type-check a call to a builtin operator and set ``ast.annotated_type``.

        * ``ite``: the condition must be implicitly convertible to bool; both
          branches are converted to their combined type. A private condition makes
          the whole expression private (@me), otherwise the branches' combined
          privacy is used.
        * parenthesis: forwards the argument's annotated type.
        * any other operator: arguments are checked against ``func.input_types()``
          (skipped when ``func.is_eq()``), argument/output types are inferred (the
          operation is evaluated via ``func.op_func`` when the argument type resolves
          to the marker ``'lit'``), and private operands make the operation
          private (@me) if ``func.can_be_private()``.

        Side effects: may set ``func.is_private``, may issue compiler warnings, and
        replaces entries of ``ast.args`` with the result of ``self.get_rhs``
        (implicit conversions).

        :raise TypeMismatchException: on a non-boolean ite condition or an argument
                                      that does not match the operator signature
        :raise TypeException: on invalid private shift/bitwise operands, or when the
                              operator does not support private operands
        """
        # handle special cases
        if func.is_ite():
            cond_t = ast.args[0].annotated_type

            # Ensure that condition is boolean
            if not cond_t.type_name.implicitly_convertible_to(
                    TypeName.bool_type()):
                raise TypeMismatchException(TypeName.bool_type(),
                                            cond_t.type_name, ast.args[0])

            # Common type of both branches (second argument True -> convert literals)
            res_t = ast.args[1].annotated_type.type_name.combined_type(
                ast.args[2].annotated_type.type_name, True)

            if cond_t.is_private():
                # Everything is turned private
                func.is_private = True
                a = res_t.annotate(Expression.me_expr())
            else:
                p = ast.args[1].annotated_type.combined_privacy(
                    ast.analysis, ast.args[2].annotated_type)
                a = res_t.annotate(p)
            # Convert both branches to the resulting annotated type
            ast.args[1] = self.get_rhs(ast.args[1], a)
            ast.args[2] = self.get_rhs(ast.args[2], a)

            ast.annotated_type = a
            return
        elif func.is_parenthesis():
            ast.annotated_type = ast.args[0].annotated_type
            return

        # Check that argument types conform to op signature
        parameter_types = func.input_types()
        if not func.is_eq():
            for arg, t in zip(ast.args, parameter_types):
                if not arg.instanceof_data_type(t):
                    raise TypeMismatchException(t,
                                                arg.annotated_type.type_name,
                                                arg)

        # t2 is None for unary operators
        t1 = ast.args[0].annotated_type.type_name
        t2 = None if len(
            ast.args) == 1 else ast.args[1].annotated_type.type_name

        if len(ast.args) == 1:
            # Unary: a literal operand is represented by the marker 'lit'
            arg_t = 'lit' if ast.args[
                0].annotated_type.type_name.is_literal else t1
        else:
            assert len(ast.args) == 2
            # Literals are only converted when comparing tuples with is_eq;
            # otherwise combined_type may yield the 'lit' marker
            is_eq_with_tuples = func.is_eq() and isinstance(t1, TupleType)
            arg_t = t1.combined_type(t2, convert_literals=is_eq_with_tuples)

        # Infer argument and output types
        if arg_t == 'lit':
            # Literal operands: evaluate the operation now and give the result a
            # literal type carrying the computed value
            res = func.op_func(
                *[arg.annotated_type.type_name.value for arg in ast.args])
            if isinstance(res, bool):
                assert func.output_type() == TypeName.bool_type()
                out_t = BooleanLiteralType(res)
            else:
                assert func.output_type() == TypeName.number_type()
                out_t = NumberLiteralType(res)
            if func.is_eq():
                # is_eq needs a concrete (non-'lit') argument type, see assert below
                arg_t = t1.to_abstract_type().combined_type(
                    t2.to_abstract_type(), True)
        elif func.output_type() == TypeName.bool_type():
            out_t = TypeName.bool_type()
        else:
            out_t = arg_t

        assert arg_t is not None and (arg_t != 'lit' or not func.is_eq())

        private_args = any(map(self.has_private_type, ast.args))
        if private_args:
            assert arg_t != 'lit'
            if func.can_be_private():
                if func.is_shiftop():
                    # Private shifts require a non-negative literal shift amount
                    if not ast.args[1].annotated_type.type_name.is_literal:
                        raise TypeException(
                            'Private shift expressions must use a constant (literal) shift amount',
                            ast.args[1])
                    if ast.args[1].annotated_type.type_name.value < 0:
                        raise TypeException('Cannot shift by negative amount',
                                            ast.args[1])
                if func.is_bitop() or func.is_shiftop():
                    # Private bit operations are rejected for 256-bit operands
                    for arg in ast.args:
                        if arg.annotated_type.type_name.elem_bitwidth == 256:
                            raise TypeException(
                                'Private bitwise and shift operations are only supported for integer types < 256 bit, '
                                'please use a smaller type', arg)

                # 256-bit operands only produce warnings (at most one per call)
                if func.is_arithmetic():
                    for a in ast.args:
                        if a.annotated_type.type_name.elem_bitwidth == 256:
                            issue_compiler_warning(
                                func, 'Possible field prime overflow',
                                'Private arithmetic 256bit operations overflow at FIELD_PRIME.\n'
                                'If you need correct overflow behavior, use a smaller integer type.'
                            )
                            break
                elif func.is_comp():
                    for a in ast.args:
                        if a.annotated_type.type_name.elem_bitwidth == 256:
                            issue_compiler_warning(
                                func, 'Possible private comparison failure',
                                'Private 256bit comparison operations will fail for values >= 2^252.\n'
                                'If you cannot guarantee that the value stays in range, you must use '
                                'a smaller integer type to ensure correctness.'
                            )
                            break

                func.is_private = True
                p = Expression.me_expr()
            else:
                raise TypeException(
                    f'Operation \'{func.op}\' does not support private operands',
                    ast)
        else:
            p = None

        # Literal-only operations need no argument conversion
        if arg_t != 'lit':
            # Add implicit casts for arguments
            arg_pt = arg_t.annotate(p)
            if func.is_shiftop() and p is not None:
                # Private shift: only the shifted value is converted; the shift
                # amount is a literal (checked above)
                ast.args[0] = self.get_rhs(ast.args[0], arg_pt)
            else:
                ast.args[:] = map(
                    lambda argument: self.get_rhs(argument, arg_pt), ast.args)

        ast.annotated_type = out_t.annotate(p)
示例#10
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition,
                                     ext_circuit: CircuitHelper,
                                     original_params: List[Parameter],
                                     requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        The generated body checks the size of the zk output array, declares the zk
        input array and fills it with the requested global public keys (@me's key
        first) followed by the encrypted parameters. It then calls the internal
        function and, if ``requires_proof``, calls the verifier contract. Finally it
        returns the internal function's return values, if it has any.

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: whether a call to the verifier contract is emitted
        :return: body with wrapper code
        """
        has_priv_args = any(
            [p.annotated_type.is_cipher() for p in original_params])
        stmts = []

        if has_priv_args:
            # Encrypted arguments require @me's public key in the circuit
            ext_circuit._require_public_key_for_label_at(
                None, Expression.me_expr())
        if cfg.is_symmetric_cipher():
            # Make sure msg.sender's key pair is available in the circuit
            assert any(isinstance(k, MeExpr) for k in ext_circuit.requested_global_keys) \
                   or has_priv_args, "requires verification => both sender keys required"
            stmts += ext_circuit.request_private_key()

        # Verify that out parameter has correct size
        stmts += [
            RequireStatement(
                IdentifierExpr(cfg.zk_out_name).dot('length').binop(
                    '==', NumberLiteralExpr(ext_circuit.out_size_trans)))
        ]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(
            Array(AnnotatedTypeName.uint_all()))

        # Find index of me's public key in requested_global_keys
        glob_me_key_index = -1
        for idx, e in enumerate(ext_circuit.requested_global_keys):
            if isinstance(e, MeExpr):
                glob_me_key_index = idx
                break

        # Request static public keys
        offset = 0
        key_req_stmts = []
        if ext_circuit.requested_global_keys:
            # Ensure that me public key is stored starting at in[0]
            keys = [key for key in ext_circuit.requested_global_keys]
            if glob_me_key_index != -1:
                (keys[0], keys[glob_me_key_index]) = (keys[glob_me_key_index],
                                                      keys[0])

            # All keys are fetched through one temporary variable
            tmp_key_var = Identifier('_tmp_key')
            key_req_stmts.append(
                tmp_key_var.decl_var(AnnotatedTypeName.key_type()))
            for key_owner in keys:
                idf, assignment = ext_circuit.request_public_key(
                    key_owner, ext_circuit.get_glob_key_name(key_owner))
                # Redirect the key request into the temporary variable
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Manually add to circuit inputs
                key_req_stmts.append(
                    in_arr_var.slice(offset, cfg.key_len).assign(
                        IdentifierExpr(tmp_key_var.clone()).slice(
                            0, cfg.key_len)))
                offset += cfg.key_len
                # NOTE(review): checked after every key, so ext_circuit.in_size
                # presumably counts the inputs added so far — confirm
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                assign_stmt = in_arr_var.slice(
                    offset, cfg.cipher_payload_len).assign(
                        IdentifierExpr(p.idf.clone()).slice(
                            0, cfg.cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cfg.cipher_payload_len

        if cfg.is_symmetric_cipher():
            # Populate sender field of encrypted parameters
            copy_stmts = []
            for p in original_params:
                if p.annotated_type.is_cipher():
                    # @me's key was moved to in[0] above
                    sender_key = in_arr_var.index(0)
                    idf = IdentifierExpr(p.idf.clone()).as_type(
                        p.annotated_type.clone())
                    # Payload elements followed by the sender key
                    lit = ArrayLiteralExpr([
                        idf.clone().index(i)
                        for i in range(cfg.cipher_payload_len)
                    ] + [sender_key])
                    copy_stmts.append(
                        VariableDeclarationStatement(
                            VariableDeclaration([], p.annotated_type.clone(),
                                                p.idf.clone(), 'memory'), lit))
            if copy_stmts:
                param_stmts += [
                    Comment(),
                    Comment(
                        'Copy from calldata to memory and set sender field')
                ] + copy_stmts

            assert glob_me_key_index != -1, "Symmetric cipher but did not request me key"

        # Declare in array
        new_in_array_expr = NewExpr(
            AnnotatedTypeName(TypeName.dyn_uint_array()),
            [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(),
                                              new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        stmts += Comment.comment_wrap_block('Request static public keys',
                                            key_req_stmts)
        stmts += Comment.comment_wrap_block(
            'Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(
            IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # NOTE(review): `args` is extended after internal_call was built; this
            # relies on FunctionCallExpr keeping a reference to the same list — confirm
            args += [
                in_arr_var.clone(),
                NumberLiteralExpr(ext_circuit.in_size),
                IdentifierExpr(cfg.zk_out_name),
                NumberLiteralExpr(ext_circuit.out_size)
            ]

        if int_fct.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [
                VariableDeclarationStatement(deep_copy(vd))
                for vd in int_fct.return_var_decls
            ])
            in_call = TupleExpr([
                IdentifierExpr(vd.idf.clone())
                for vd in int_fct.return_var_decls
            ]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof:
            verifier = IdentifierExpr(
                cfg.get_contract_var_name(
                    ext_circuit.verifier_contract_type.code()))
            verifier_args = [
                IdentifierExpr(cfg.proof_param_name),
                IdentifierExpr(cfg.zk_in_name),
                IdentifierExpr(cfg.zk_out_name)
            ]
            verify = ExpressionStatement(
                verifier.call(cfg.verification_function_name, verifier_args))
            stmts.append(
                StatementList(
                    [Comment('Verify zk proof of execution'), verify],
                    excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(
                ReturnStatement(
                    TupleExpr([
                        IdentifierExpr(vd.idf.clone())
                        for vd in int_fct.return_var_decls
                    ])))

        return Block(stmts)
示例#11
0
文件: build_ast.py 项目: nibau/zkay
 def visitIteExpr(self, ctx: SolidityParser.IteExprContext):
     """
     Build a call to the builtin 'ite' (if-then-else) with condition, then- and else-branch.

     Like the other builtin-operator visitors, the BuiltinFunction node gets a
     source location. No single operator token is used here, so it is placed at
     the start of the whole expression (``ctx.start``).
     """
     f = BuiltinFunction('ite').override(line=ctx.start.line, column=ctx.start.column)
     cond = self.visit(ctx.cond)
     then_expr = self.visit(ctx.then_expr)
     else_expr = self.visit(ctx.else_expr)
     return FunctionCallExpr(f, [cond, then_expr, else_expr])
示例#12
0
文件: build_ast.py 项目: nibau/zkay
 def _visitBinaryExpr(self, ctx):
     """Build a builtin call ``lhs <op> rhs``, located at the operator token."""
     left = self.visit(ctx.lhs)
     right = self.visit(ctx.rhs)
     op_tok = ctx.op
     builtin = BuiltinFunction(op_tok.text).override(line=op_tok.line, column=op_tok.column)
     return FunctionCallExpr(builtin, [left, right])
示例#13
0
文件: build_ast.py 项目: nibau/zkay
 def visitBitwiseNotExpr(self, ctx: SolidityParser.BitwiseNotExprContext):
     """Build a builtin '~' call on the operand, located at the start of the expression."""
     start_tok = ctx.start
     not_fn = BuiltinFunction('~').override(line=start_tok.line, column=start_tok.column)
     return FunctionCallExpr(not_fn, [self.visit(ctx.expr)])
示例#14
0
文件: build_ast.py 项目: nibau/zkay
 def visitSignExpr(self, ctx: SolidityParser.SignExprContext):
     """Build a unary sign call; the builtin is named 'sign' followed by the operator text."""
     op_tok = ctx.op
     sign_fn = BuiltinFunction('sign' + op_tok.text).override(line=op_tok.line, column=op_tok.column)
     return FunctionCallExpr(sign_fn, [self.visit(ctx.expr)])
示例#15
0
文件: build_ast.py 项目: nibau/zkay
 def visitParenthesisExpr(self, ctx: SolidityParser.ParenthesisExprContext):
     """Wrap the inner expression in a builtin 'parenthesis' call located at the opening token."""
     start_tok = ctx.start
     paren_fn = BuiltinFunction('parenthesis').override(line=start_tok.line, column=start_tok.column)
     return FunctionCallExpr(paren_fn, [self.visit(ctx.expr)])
示例#16
0
文件: test_ast.py 项目: nibau/zkay
 def test_builtin_code(self):
     """A builtin '+' call on two zero literals renders as infix code."""
     plus = BuiltinFunction('+')
     operands = [NumberLiteralExpr(0), NumberLiteralExpr(0)]
     expr = FunctionCallExpr(plus, operands)
     self.assertEqual(expr.code(), '0 + 0')
示例#17
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition, ext_circuit: CircuitHelper,
                                     original_params: List[Parameter], requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        The generated body (per crypto backend) requests msg.sender's private key for
        symmetric ciphers, checks the size of the zk output array, declares the zk
        input array and fills it with the requested global public keys followed by
        the encrypted parameters, calls the internal function, optionally calls the
        verifier contract, and returns the internal function's return values.

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: whether a verifier call is emitted (also skipped when cfg.disable_verification)
        :return: body with wrapper code
        """
        priv_args = [p for p in original_params if p.annotated_type.is_cipher()]
        # Ordered, de-duplicated crypto backends used by the encrypted arguments
        args_backends = OrderedDict.fromkeys([p.annotated_type.type_name.crypto_params for p in priv_args])
        stmts = []

        for crypto_params in args_backends:
            assert crypto_params in int_fct.used_crypto_backends
            # If there are any private arguments with homomorphism 'hom', we need the public key for that crypto backend
            ext_circuit._require_public_key_for_label_at(None, Expression.me_expr(), crypto_params)
        for crypto_params in cfg.all_crypto_params():
            if crypto_params.is_symmetric_cipher():
                if (MeExpr(), crypto_params) in ext_circuit.requested_global_keys or crypto_params in args_backends:
                    # Make sure msg.sender's key pair is available in the circuit
                    stmts += ext_circuit.request_private_key(crypto_params)

        # Verify that out parameter has correct size
        stmts += [RequireStatement(IdentifierExpr(cfg.zk_out_name).dot('length').binop('==', NumberLiteralExpr(ext_circuit.out_size_trans)))]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))

        # Request static public keys
        offset = 0
        key_req_stmts = []
        # Offset into the in array of @me's key, per crypto backend
        me_key_idx: Dict[CryptoParams, int] = {}
        if ext_circuit.requested_global_keys:
            # Copy of the requested (owner, crypto_params) keys, in request order.
            # No reordering is done here: @me's key offsets are recorded in me_key_idx.
            keys = [key for key in ext_circuit.requested_global_keys]

            # One temporary key variable per crypto backend
            tmp_keys = {}
            for crypto_params in int_fct.used_crypto_backends:
                tmp_key_var = Identifier(f'_tmp_key_{crypto_params.identifier_name}')
                key_req_stmts.append(tmp_key_var.decl_var(AnnotatedTypeName.key_type(crypto_params)))
                tmp_keys[crypto_params] = tmp_key_var
            for (key_owner, crypto_params) in keys:
                tmp_key_var = tmp_keys[crypto_params]
                idf, assignment = ext_circuit.request_public_key(crypto_params, key_owner, ext_circuit.get_glob_key_name(key_owner, crypto_params))
                # Redirect the key request into the backend's temporary variable
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Remember me-keys for later use in symmetrically encrypted keys
                if key_owner == MeExpr():
                    assert crypto_params not in me_key_idx
                    me_key_idx[crypto_params] = offset

                # Manually add to circuit inputs
                key_len = crypto_params.key_len
                key_req_stmts.append(in_arr_var.slice(offset, key_len).assign(IdentifierExpr(tmp_key_var.clone()).slice(0, key_len)))
                offset += key_len
                # NOTE(review): checked after every key, so ext_circuit.in_size
                # presumably counts the inputs added so far — confirm
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                cipher_payload_len = p.annotated_type.type_name.crypto_params.cipher_payload_len
                assign_stmt = in_arr_var.slice(offset, cipher_payload_len).assign(IdentifierExpr(p.idf.clone()).slice(0, cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cipher_payload_len

        # Populate sender field of parameters encrypted with a symmetric cipher
        copy_stmts = []
        for p in original_params:
            if p.annotated_type.is_cipher():
                c = p.annotated_type.type_name
                assert isinstance(c, CipherText)
                if c.crypto_params.is_symmetric_cipher():
                    # @me's key for this backend, as recorded while requesting keys
                    sender_key = in_arr_var.index(me_key_idx[c.crypto_params])
                    idf = IdentifierExpr(p.idf.clone()).as_type(p.annotated_type.clone())
                    # NOTE(review): the payload length is looked up via cfg.get_crypto_params(homomorphism)
                    # here but via c.crypto_params above; presumably identical — confirm
                    cipher_payload_len = cfg.get_crypto_params(p.annotated_type.homomorphism).cipher_payload_len
                    # Payload elements followed by the sender key
                    lit = ArrayLiteralExpr([idf.clone().index(i) for i in range(cipher_payload_len)] + [sender_key])
                    copy_stmts.append(VariableDeclarationStatement(VariableDeclaration([], p.annotated_type.clone(), p.idf.clone(), 'memory'), lit))
        if copy_stmts:
            param_stmts += [Comment(), Comment('Copy from calldata to memory and set sender field')] + copy_stmts

        # Declare in array
        new_in_array_expr = NewExpr(AnnotatedTypeName(TypeName.dyn_uint_array()), [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(), new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        stmts += Comment.comment_wrap_block('Request static public keys', key_req_stmts)
        stmts += Comment.comment_wrap_block('Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # NOTE(review): `args` is extended after internal_call was built; this
            # relies on FunctionCallExpr keeping a reference to the same list — confirm
            args += [in_arr_var.clone(), NumberLiteralExpr(ext_circuit.in_size),
                     IdentifierExpr(cfg.zk_out_name), NumberLiteralExpr(ext_circuit.out_size)]

        if int_fct.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [VariableDeclarationStatement(deep_copy(vd)) for vd in int_fct.return_var_decls])
            in_call = TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof and not cfg.disable_verification:
            verifier = IdentifierExpr(cfg.get_contract_var_name(ext_circuit.verifier_contract_type.code()))
            verifier_args = [IdentifierExpr(cfg.proof_param_name), IdentifierExpr(cfg.zk_in_name), IdentifierExpr(cfg.zk_out_name)]
            verify = ExpressionStatement(verifier.call(cfg.verification_function_name, verifier_args))
            stmts.append(StatementList([Comment('Verify zk proof of execution'), verify], excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(ReturnStatement(TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls])))

        return Block(stmts)
示例#18
0
    def evaluate_stmt_in_circuit(self, ast: Statement) -> Statement:
        """
        Evaluate an entire statement privately.

        This works by turning the statement into an assignment statement where the

        * lhs is a tuple of all external locations (defined outside statement), which are modified inside the statement
        * rhs is the return value of an inlined function call expression to a virtual function where body = the statement
          + return statement which returns a tuple of the most recent SSA version of all modified locations

        If the statement does not modify any location that is in scope outside of it, no assignment is needed and a
        no-op expression statement (the literal ``0``) is returned instead.

        Note: Modifying external locations which are not owned by @me inside the statement is illegal (would leak information).
        Note: At the moment, this is only used for if statements with a private condition.

        :param ast: the statement to evaluate inside the circuit
        :raises NotImplementedError: if a modified external location has a non-primitive (reference) type
        :return: AssignmentStatement as described above, or an ExpressionStatement wrapping the literal 0 if no
                 external location is modified
        """
        # Placeholder result: a no-op ``0`` expression statement, replaced by an (initially empty) assignment
        # statement as soon as at least one modified location is in scope at ``ast`` (i.e. external to it).
        astmt = ExpressionStatement(NumberLiteralExpr(0))
        for var in ast.modified_values:
            if var.in_scope_at(ast):
                astmt = AssignmentStatement(None, None)
                break

        # Carry the privacy partition analysis of the original statement over to its replacement.
        astmt.before_analysis = ast.before_analysis

        # External values written inside statement -> function return values
        ret_params = []
        for var in ast.modified_values:
            if var.in_scope_at(ast):
                # side effect affects location outside statement and has privacy @me
                assert ast.before_analysis.same_partition(
                    var.privacy, Expression.me_expr())
                assert isinstance(
                    var.target,
                    (Parameter, VariableDeclaration, StateVariableDeclaration))
                t = var.target.annotated_type.zkay_type
                if not t.type_name.is_primitive_type():
                    raise NotImplementedError(
                        'Reference types inside private if statements are not supported'
                    )
                # Same type name and homomorphism as the modified variable, but with privacy @me.
                ret_t = AnnotatedTypeName(t.type_name, Expression.me_expr(),
                                          t.homomorphism)  # t, but @me
                ret_param = IdentifierExpr(var.target.idf.clone(),
                                           ret_t).override(target=var.target)
                # Associate the new identifier with the replacement statement.
                ret_param.statement = astmt
                ret_params.append(ret_param)

        # Build the imaginary function '<stmt_fct>': modifier 'private', one Parameter per entry of ret_params
        # (presumably its return parameters -- confirm against ConstructorOrFunctionDefinition's signature) and a
        # body consisting of the original statement followed by returning ret_params as a tuple.
        fdef = ConstructorOrFunctionDefinition(
            Identifier('<stmt_fct>'), [], ['private'], [
                Parameter([], ret.annotated_type, ret.target.idf)
                for ret in ret_params
            ], Block([ast, ReturnStatement(TupleExpr(ret_params))]))
        fdef.original_body = fdef.body
        fdef.body.parent = fdef
        # Hook the imaginary function into the AST below the original statement.
        fdef.parent = ast

        # inline "Call" to the imaginary function
        fcall = FunctionCallExpr(
            IdentifierExpr('<stmt_fct>').override(target=fdef), [])
        fcall.statement = astmt
        ret_args = self.inline_function_call_into_circuit(fcall)

        # Move all return values out of the circuit.
        # Normalize to a TupleExpr so a single returned value is handled like several.
        if not isinstance(ret_args, TupleExpr):
            ret_args = TupleExpr([ret_args])
        for ret_arg in ret_args.elements:
            ret_arg.statement = astmt
        # One circuit output per returned value, with privacy @me and the homomorphism of the matching return param.
        ret_arg_outs = [
            self._get_circuit_output_for_private_expression(
                ret_arg, Expression.me_expr(),
                ret_param.annotated_type.homomorphism)
            for ret_param, ret_arg in zip(ret_params, ret_args.elements)
        ]

        # Create assignment statement
        if ret_params:
            # astmt is the AssignmentStatement created above; fill in both sides now.
            astmt.lhs = TupleExpr(
                [ret_param.clone() for ret_param in ret_params])
            astmt.rhs = TupleExpr(ret_arg_outs)
            return astmt
        else:
            # Nothing external was modified: astmt is still the ``0`` placeholder.
            assert isinstance(astmt, ExpressionStatement)
            return astmt
示例#19
0
 def join(then_idf, else_idf):
     """Merge the values of two branches into a fresh temporary.

     Builds the expression ``cond ? then_idf : else_idf``, where ``cond`` is a clone of
     ``true_cond_for_other_branch`` from the enclosing scope. It types that expression as ``val.t``
     and passes it to ``create_val_for_name_and_expr_fct`` under ``key.name``. The result is the
     new temporary HybridArgumentIdf.
     """
     ite_fct = BuiltinFunction('ite')
     ite_operands = [true_cond_for_other_branch.clone(), then_idf, else_idf]
     merged_expr = FunctionCallExpr(ite_fct, ite_operands).as_type(val.t)
     return create_val_for_name_and_expr_fct(key.name, merged_expr)
示例#20
0
    def visitFunctionCallExpr(self, ast: FunctionCallExpr):
        """
        Transform a function call expression according to the kind of callee.

        * Builtin with private operands: evaluated inside the circuit (modified rule 12).
        * Builtin with public operands: left untransformed (rule 10). The exception is a short-circuiting
          builtin whose later operands contain private expressions. There, the first operand becomes a circuit
          input and the remaining operands are visited under the matching guard value.
        * Enum cast: evaluated inside the circuit if the casted expression is evaluated privately, otherwise in public.
        * Regular function call: children are transformed, and the call is rerouted to the internal function if
          the callee was split into external/internal. Calls to functions requiring verification are registered
          with the current circuit. For other side-effecting callees, the modified state variables are invalidated.

        :param ast: the function call expression to transform
        :raises NotImplementedError: if a split callee is not referenced through an IdentifierExpr
        :return: the transformed expression
        """
        if isinstance(ast.func, BuiltinFunction):
            builtin = ast.func
            if builtin.is_private:
                # A private expression on its own (e.g. an IdentifierExpr of a private variable) does not trigger a
                # boundary crossing, since assigning private variables is a public operation. Builtin operations
                # with private operands, however, have to be evaluated inside the circuit.
                return self.gen.evaluate_expr_in_circuit(ast, Expression.me_expr())

            needs_guards = builtin.has_shortcircuiting() and any(
                contains_private_expr(arg) for arg in ast.args[1:])
            if not needs_guards:
                return self.visit_children(ast)

            # Short-circuiting builtin with nested private expressions in its public operands: turn the first
            # operand into a circuit input and evaluate the remaining operands only under the matching guard.
            op = builtin.op
            guard_var = self.gen.add_to_circuit_inputs(ast.args[0])
            ast.args[0] = guard_var.get_loc_expr(ast)
            if op == 'ite':
                ast.args[1] = self.visit_guarded_expression(guard_var, True, ast.args[1])
                ast.args[2] = self.visit_guarded_expression(guard_var, False, ast.args[2])
            elif op in ('||', '&&'):
                # The rhs of '&&' only runs when the guard is true, the rhs of '||' only when it is false.
                ast.args[1] = self.visit_guarded_expression(guard_var, op == '&&', ast.args[1])
            return ast

        if ast.is_cast:
            # Casts are handled in public or inside the circuit, depending on the privacy of the casted expression.
            assert isinstance(ast.func.target, EnumDefinition)
            if ast.args[0].evaluate_privately:
                return self.gen.evaluate_expr_in_circuit(ast, Expression.me_expr())
            return self.visit_children(ast)

        # Normal function call (outside the private expression case). The callee may have side effects and,
        # if it does not require verification, may even be recursive.
        assert isinstance(ast.func, LocationExpr)
        ast = self.visit_children(ast)
        callee = ast.func.target
        if callee.requires_verification_when_external:
            # The callee was split into external/internal: reroute the call to the internal function.
            if not isinstance(ast.func, IdentifierExpr):
                raise NotImplementedError()
            ast.func.idf.name = cfg.get_internal_name(callee)

        if callee.requires_verification:
            # The callee has an associated circuit, so this function's circuit must be made aware of the call.
            self.gen.call_function(ast)
        elif callee.has_side_effects and self.gen is not None:
            # Invalidate modified state variables for the current circuit.
            written_state_vars = (
                val.target for val in ast.modified_values
                if val.key is None and isinstance(val.target, StateVariableDeclaration))
            for state_var in written_state_vars:
                self.gen.invalidate_idf(state_var.idf)

        # The call stays a normal function call in the output solidity code.
        return ast