예제 #1
0
    def visitIdentifierExpr(self, ast: IdentifierExpr):
        """
        Return the code string for an identifier expression.

        Cases, checked in this order:
        * rvalue references to a PKI contract instance ('<pki contract name>_inst') -> keystore lookup call
        * the field prime variable (cfg.field_prime_var_name) -> SCALAR_FIELD_NAME
        * an identifier which starts a pending index chain (self.current_index) -> get_value with the visited indices
        * inside the circuit, a HybridArgumentIdf with a corresponding private expression (when
          self.flatten_hybrid_args is set) -> code of that private expression
        * otherwise -> get_value without indices
        """
        # Special identifiers
        # Maps '<pki contract name>_inst' -> crypto params, for every configured crypto backend
        pki_inst_names = {
            f'{cfg.get_pki_contract_name(params)}_inst': params
            for params in cfg.all_crypto_params()
        }
        if ast.idf.name in pki_inst_names and not ast.is_lvalue():
            crypto_params = pki_inst_names[ast.idf.name]
            return f'{api("get_keystore")}("{crypto_params.crypto_name}")'
        elif ast.idf.name == cfg.field_prime_var_name:
            assert ast.is_rvalue()
            return f'{SCALAR_FIELD_NAME}'

        if self.current_index:
            # This identifier is the beginning of an Index expression e.g. idf[1][2] or idf[me]
            # NOTE(review): current_index is reversed here, presumably because it was collected from the
            # outermost index inwards while descending the IndexExpr chain - confirm against visitIndexExpr.
            indices, t = list(reversed(
                self.current_index)), self.current_index_t
            # Reset the pending index state *before* visiting the index expressions, so that an identifier
            # occurring inside an index (e.g. idf[other]) does not pick up this chain's indices.
            self.current_index, self.current_index_t = [], None
            indices = [self.visit(idx) for idx in indices]
        elif self.inside_circuit and isinstance(
                ast.idf, HybridArgumentIdf
        ) and ast.idf.corresponding_priv_expression is not None and self.flatten_hybrid_args:
            # Emit the private expression this hybrid argument stands for instead of the argument itself
            return self.visit(ast.idf.corresponding_priv_expression)
        else:
            indices, t = [], ast.target.annotated_type if isinstance(
                ast.target, StateVariableDeclaration) else None

        # NOTE(review): `t` is computed in both branches above but is not passed to get_value - confirm intent.
        return self.get_value(ast, indices)
예제 #2
0
def transform_internal_calls(
        fcts_with_verification: List[ConstructorOrFunctionDefinition],
        cgens: Dict[ConstructorOrFunctionDefinition, CircuitHelper]):
    """
    Append the circuit IO arguments to every call of a function which requires verification.

    Must be called after compute_transitive_circuit_io_sizes, and only in a second pass once all
    function bodies in the contract are fully transformed: the start indices handed to a callee
    depend on the circuit IO sizes of the calling function
    (see ZkayTransformer documentation for more information).

    For every such call, the caller forwards its circuit input and output arrays together with the
    start indices at which the callee deserializes/serializes its own segment of the
    output/input array. The callee's private input segment start is stored in
    ``sec_start_offset`` of the call expression.

    :param fcts_with_verification: [SIDE EFFECT] All functions which have a circuit associated with them
    :param cgens: A map from function to circuit
    """
    for caller in fcts_with_verification:
        caller_circuit = cgens[caller]
        # Running offsets (relative to the caller's own IO) of the segments consumed by earlier callees
        in_off = out_off = priv_off = 0

        for call in caller_circuit.function_calls_with_verification:
            callee = call.func.target
            call.sec_start_offset = caller_circuit.priv_in_size + priv_off

            # Built in the same order as they are appended: in array, in start, out array, out start
            in_arr = IdentifierExpr(cfg.zk_in_name)
            in_start = IdentifierExpr(f'{cfg.zk_in_name}_start_idx').binop(
                '+', NumberLiteralExpr(caller_circuit.in_size + in_off))
            out_arr = IdentifierExpr(cfg.zk_out_name)
            out_start = IdentifierExpr(f'{cfg.zk_out_name}_start_idx').binop(
                '+', NumberLiteralExpr(caller_circuit.out_size + out_off))
            call.args += [in_arr, in_start, out_arr, out_start]

            # Advance past everything the callee (transitively) consumes
            callee_circuit = cgens[callee]
            in_off += callee_circuit.in_size_trans
            out_off += callee_circuit.out_size_trans
            priv_off += callee_circuit.priv_in_size_trans

        assert in_off == caller_circuit.trans_in_size and out_off == caller_circuit.trans_out_size and priv_off == caller_circuit.trans_priv_size
예제 #3
0
파일: final_checker.py 프로젝트: nibau/zkay
 def visitIdentifierExpr(self, ast: IdentifierExpr):
     """
     Reject reads of a "final" state variable which has not been written yet.

     Only rvalue uses are checked, and only while ``self.state_vars_assigned`` is tracked (not None).

     :raises TypeException: if ast reads a tracked state variable whose assigned-flag is falsy
     """
     if not ast.is_rvalue() or self.state_vars_assigned is None:
         return

     assigned = self.state_vars_assigned
     target = ast.target
     if target in assigned and not assigned[target]:
         raise TypeException(
             f'{str(ast)} is reading "final" state variable before writing it',
             ast)
예제 #4
0
    def request_public_key(self, crypto_params: CryptoParams,
                           plabel: Union[MeExpr, Identifier], name: str):
        """
        Fetch the key of plabel's address from the PKI contract and register it as a public circuit input.

        :param crypto_params: crypto backend whose PKI contract is queried and whose key type the new input gets
        :param plabel: privacy label for which to request key
        :param name: name to use for the HybridArgumentIdf holding the key
        :return: HybridArgumentIdf containing the requested key and an AssignmentStatement which assigns the key request to the idf location
        """
        key_type = TypeName.key_type(crypto_params)
        key_idf = self._in_name_factory.add_idf(name, key_type)

        pki_contract_name = cfg.get_pki_contract_name(crypto_params)
        pki = IdentifierExpr(cfg.get_contract_var_name(pki_contract_name))
        label_expr = get_privacy_expr_from_label(plabel)

        # Same evaluation order as a single nested expression: location first, then the transformed label
        key_location = key_idf.get_loc_expr()
        key_request = pki.call('getPk', [self._expr_trafo.visit(label_expr)])
        return key_idf, key_location.assign(key_request)
예제 #5
0
파일: test_ast.py 프로젝트: nibau/zkay
 def test_assignment_statement(self):
     """An AssignmentStatement prints as Solidity, exposes lhs/rhs as children and starts detached."""
     target = IdentifierExpr(Identifier('x'))
     value = BooleanLiteralExpr(True)
     stmt = AssignmentStatement(target, value)

     self.assertIsNotNone(stmt)
     self.assertEqual(str(stmt), 'x = true;')
     self.assertEqual(stmt.children(), [target, value])
     self.assertDictEqual(stmt.names, {})
     self.assertIsNone(stmt.parent)
예제 #6
0
파일: type_checker.py 프로젝트: nibau/zkay
    def visitIdentifierExpr(self, ast: IdentifierExpr):
        """
        Annotate an identifier expression with a copy of its target's type and check read access.

        Identifiers referring to a Mapping are left untouched (the identifier will be replaced later).

        :raises TypeException: if the identifier refers to a contract (contract types are not supported
                               in expressions), or if the value cannot be proven to be owned by the
                               transaction invoker
        """
        target = ast.target
        if isinstance(target, Mapping):
            # no action necessary, the identifier will be replaced later
            return

        if isinstance(target, ContractDefinition):
            # Plain string: the message has no placeholders (was a placeholder-less f-string)
            raise TypeException(
                'Unsupported use of contract type in expression', ast)
        ast.annotated_type = target.annotated_type.clone()

        if not self.is_accessible_by_invoker(ast):
            raise TypeException(
                "Tried to read value which cannot be proven to be owned by the transaction invoker",
                ast)
예제 #7
0
    def visitReturnStatement(self, ast: ReturnStatement):
        """
        Handle return statement.

        In a function which does not require verification, only the returned expression is transformed
        and the statement itself is kept.

        In a function which requires verification, the statement is replaced by an assignment of the
        transformed expression to the function's return variables (which are returned at the very end
        of the function body, after any verification wrapper code). A bare ``return;`` is dropped
        (None is returned).
        """
        if not ast.function.requires_verification:
            ast.expr = self.expr_trafo.visit(ast.expr)
            return ast

        if ast.expr is None:
            return None

        # Only one value-returning statement may be rewritten per self.gen
        assert not self.gen.has_return_var
        self.gen.has_return_var = True

        transformed = self.expr_trafo.visit(ast.expr)
        ret_vars = [
            IdentifierExpr(decl.idf.clone()).override(target=decl)
            for decl in ast.function.return_var_decls
        ]
        assignment = TupleExpr(ret_vars).assign(transformed)
        return assignment.override(pre_statements=ast.pre_statements)
예제 #8
0
    def inline_function_call_into_circuit(
            self, fcall: FunctionCallExpr) -> Union[Expression, TupleExpr]:
        """
        Inline an entire function call into the current circuit.

        Each argument is bound to a new version of the corresponding callee parameter idf, then a deep
        copy of the callee's untransformed body (fdef.original_body) is visited by the circuit
        transformer, and the pre-statements generated for that copy are appended to
        fcall.statement.pre_statements. For the virtual function named '<stmt_fct>', no extra indented
        "INLINED ..." circuit block is opened (nullcontext is used instead).

        :param fcall: Function call to inline (fcall.func must be a LocationExpr with a resolved target)
        :return: Expression (1 retval) / TupleExpr (multiple retvals) with return value(s)
        """
        assert isinstance(fcall.func,
                          LocationExpr) and fcall.func.target is not None
        fdef = fcall.func.target
        # Open a remapper scope for the callee body (presumably so that callee-local versions do not
        # leak out of the inlined call - TODO confirm against remap_scope)
        with self._remapper.remap_scope(fcall.func.target.body):
            with nullcontext(
            ) if fcall.func.target.idf.name == '<stmt_fct>' else self.circ_indent_block(
                    f'INLINED {fcall.code()}'):
                # Assign all arguments to temporary circuit variables which are designated as the current version of the parameter idfs
                for param, arg in zip(fdef.parameters, fcall.args):
                    self.phi.append(
                        CircComment(f'ARG {param.idf.name}: {arg.code()}'))
                    with self.circ_indent_block():
                        self.create_new_idf_version_from_value(param.idf, arg)

                # Visit the untransformed target function body to include all statements in this circuit
                # (a deep copy, so the callee's own original_body is left unmodified)
                inlined_body = deep_copy(fdef.original_body,
                                         with_types=True,
                                         with_analysis=True)
                self._circ_trafo.visit(inlined_body)
                fcall.statement.pre_statements += inlined_body.pre_statements

                # Create TupleExpr with location expressions corresponding to the function return values as elements
                # (looked up while the scope is still open, i.e. the versions current at the end of the inlined body)
                ret_idfs = [
                    self._remapper.get_current(vd.idf)
                    for vd in fdef.return_var_decls
                ]
                ret = TupleExpr([
                    IdentifierExpr(idf.clone()).as_type(idf.t)
                    for idf in ret_idfs
                ])
        if len(ret.elements) == 1:
            # Unpack 1-length tuple
            ret = ret.elements[0]
        return ret
예제 #9
0
 def visitMeExpr(ast: MeExpr):
     """Replace me with msg.sender, typed as address@all."""
     msg_sender = IdentifierExpr('msg').dot('sender')
     replaced = replace_expr(ast, msg_sender)
     return replaced.as_type(AnnotatedTypeName.address_all())
예제 #10
0
    def add_to_circuit_inputs(self, expr: Expression) -> HybridArgumentIdf:
        """
        Add the provided expression to the public circuit inputs.

        Roughly corresponds to in() from paper

        If expr is encrypted (privacy != @all), this function also automatically ensures that the circuit has access to
        the correctly decrypted expression value in the form of a new private circuit input.

        If the transformed expression has a boolean or number literal type, no circuit input is added at all;
        the expression is evaluated inside the circuit instead (constant folding makes this cheap).

        If expr is an IdentifierExpr, its value will be cached when cfg.opt_cache_circuit_inputs is set
        (i.e. when the same identifier is needed again as a circuit input, its value will be retrieved from cache rather
        than adding an expensive redundant input. The cache is invalidated as soon as the identifier is overwritten in public code)

        Note: This function has side effects on expr.statement (adds a pre_statement)

        :param expr: [SIDE EFFECT] expression which should be made available inside the circuit as an argument
        :return: HybridArgumentIdf which references the plaintext value of the newly added input
        """
        privacy = expr.annotated_type.privacy_annotation.privacy_annotation_label(
        ) if expr.annotated_type.is_private() else Expression.all_expr()
        is_public = privacy == Expression.all_expr()

        expr_text = expr.code()
        input_expr = self._expr_trafo.visit(expr)
        t = input_expr.annotated_type.type_name
        locally_decrypted_idf = None

        # If expression has literal type -> evaluate it inside the circuit (constant folding will be used)
        # rather than introducing an unnecessary public circuit input (expensive).
        # Both literal kinds are handled identically.
        if isinstance(t, (BooleanLiteralType, NumberLiteralType)):
            return self._evaluate_private_expression(input_expr, str(t.value))

        t_suffix = ''
        if isinstance(expr, IdentifierExpr):
            # Look in cache before doing expensive move-in
            if self._remapper.is_remapped(expr.target.idf):
                remapped_idf = self._remapper.get_current(expr.target.idf)
                return remapped_idf

            # Suffix the generated input names with the identifier name for readability
            t_suffix = f'_{expr.idf.name}'

        # Generate circuit inputs
        if is_public:
            tname = f'{self._in_name_factory.get_new_name(expr.annotated_type.type_name)}{t_suffix}'
            return_idf = input_idf = self._in_name_factory.add_idf(
                tname, expr.annotated_type.type_name)
            self._phi.append(CircComment(f'{input_idf.name} = {expr_text}'))
        else:
            # Encrypted inputs need to be decrypted inside the circuit (i.e. add plain as private input and prove encryption)
            tname = f'{self._secret_input_name_factory.get_new_name(expr.annotated_type.type_name)}{t_suffix}'
            return_idf = locally_decrypted_idf = self._secret_input_name_factory.add_idf(
                tname, expr.annotated_type.type_name)
            cipher_t = TypeName.cipher_type(input_expr.annotated_type,
                                            expr.annotated_type.homomorphism)
            tname = f'{self._in_name_factory.get_new_name(cipher_t)}{t_suffix}'
            input_idf = self._in_name_factory.add_idf(
                tname, cipher_t, IdentifierExpr(locally_decrypted_idf))

        # Add a CircuitInputStatement to the solidity code, which looks like a normal assignment statement,
        # but also signals the offchain simulator to perform decryption if necessary
        expr.statement.pre_statements.append(
            CircuitInputStatement(input_idf.get_loc_expr(), input_expr))

        if not is_public:
            # Check if the secret plain input corresponds to the decrypted cipher value
            crypto_params = cfg.get_crypto_params(
                expr.annotated_type.homomorphism)
            self._phi.append(
                CircComment(
                    f'{locally_decrypted_idf} = dec({expr_text}) [{input_idf.name}]'
                ))
            self._ensure_encryption(expr.statement, locally_decrypted_idf,
                                    Expression.me_expr(), crypto_params,
                                    input_idf, False, True)

        # Cache circuit input for later reuse if possible
        if cfg.opt_cache_circuit_inputs and isinstance(expr, IdentifierExpr):
            # TODO: What if a homomorphic variable gets used as both a plain variable and as a ciphertext?
            #       This works for now because we never perform homomorphic operations on variables we can decrypt.
            self._remapper.remap(expr.target.idf, return_idf)

        return return_idf
예제 #11
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition,
                                     ext_circuit: CircuitHelper,
                                     original_params: List[Parameter],
                                     requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        The generated statements are, in order: the private key request for msg.sender (symmetric cipher only),
        a require on the size of the out array, the declaration of the in array, statements copying the
        requested global public keys (msg.sender's key first, if requested) and then the ciphertext payloads
        of the encrypted parameters into the in array, the call of the internal function (with the circuit
        IO arrays appended if it requires verification), the verifier call (if requires_proof) and a
        return of the internal function's return values (if any).

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: if True, a call to the verification contract is added after the internal call
        :return: body with wrapper code
        """
        # True if at least one original parameter is a ciphertext
        has_priv_args = any(
            [p.annotated_type.is_cipher() for p in original_params])
        stmts = []

        if has_priv_args:
            # Private arguments require msg.sender's public key to be registered with the circuit
            ext_circuit._require_public_key_for_label_at(
                None, Expression.me_expr())
        if cfg.is_symmetric_cipher():
            # Make sure msg.sender's key pair is available in the circuit
            assert any(isinstance(k, MeExpr) for k in ext_circuit.requested_global_keys) \
                   or has_priv_args, "requires verification => both sender keys required"
            stmts += ext_circuit.request_private_key()

        # Verify that out parameter has correct size
        stmts += [
            RequireStatement(
                IdentifierExpr(cfg.zk_out_name).dot('length').binop(
                    '==', NumberLiteralExpr(ext_circuit.out_size_trans)))
        ]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(
            Array(AnnotatedTypeName.uint_all()))

        # Find index of me's public key in requested_global_keys
        glob_me_key_index = -1
        for idx, e in enumerate(ext_circuit.requested_global_keys):
            if isinstance(e, MeExpr):
                glob_me_key_index = idx
                break

        # Request static public keys
        # offset = next free position in the in array
        offset = 0
        key_req_stmts = []
        if ext_circuit.requested_global_keys:
            # Ensure that me public key is stored starting at in[0]
            keys = [key for key in ext_circuit.requested_global_keys]
            if glob_me_key_index != -1:
                (keys[0], keys[glob_me_key_index]) = (keys[glob_me_key_index],
                                                      keys[0])

            # Every key is fetched into this one temporary and then copied into the in array
            tmp_key_var = Identifier('_tmp_key')
            key_req_stmts.append(
                tmp_key_var.decl_var(AnnotatedTypeName.key_type()))
            for key_owner in keys:
                idf, assignment = ext_circuit.request_public_key(
                    key_owner, ext_circuit.get_glob_key_name(key_owner))
                # Redirect the key request into the temporary instead of the idf location
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Manually add to circuit inputs
                key_req_stmts.append(
                    in_arr_var.slice(offset, cfg.key_len).assign(
                        IdentifierExpr(tmp_key_var.clone()).slice(
                            0, cfg.key_len)))
                offset += cfg.key_len
                # NOTE(review): checked per key, so this relies on each request_public_key call growing
                # ext_circuit.in_size by exactly cfg.key_len - confirm against CircuitHelper
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                # Copy p's ciphertext payload into the in array, directly after the previously added inputs
                assign_stmt = in_arr_var.slice(
                    offset, cfg.cipher_payload_len).assign(
                        IdentifierExpr(p.idf.clone()).slice(
                            0, cfg.cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cfg.cipher_payload_len

        if cfg.is_symmetric_cipher():
            # Populate sender field of encrypted parameters
            # Each cipher parameter is redeclared as a memory array holding its payload followed by in[0],
            # which holds msg.sender's key (it was moved to the front above; asserted to be requested below)
            copy_stmts = []
            for p in original_params:
                if p.annotated_type.is_cipher():
                    sender_key = in_arr_var.index(0)
                    idf = IdentifierExpr(p.idf.clone()).as_type(
                        p.annotated_type.clone())
                    lit = ArrayLiteralExpr([
                        idf.clone().index(i)
                        for i in range(cfg.cipher_payload_len)
                    ] + [sender_key])
                    copy_stmts.append(
                        VariableDeclarationStatement(
                            VariableDeclaration([], p.annotated_type.clone(),
                                                p.idf.clone(), 'memory'), lit))
            if copy_stmts:
                param_stmts += [
                    Comment(),
                    Comment(
                        'Copy from calldata to memory and set sender field')
                ] + copy_stmts

            assert glob_me_key_index != -1, "Symmetric cipher but did not request me key"

        # Declare in array
        new_in_array_expr = NewExpr(
            AnnotatedTypeName(TypeName.dyn_uint_array()),
            [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(),
                                              new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        stmts += Comment.comment_wrap_block('Request static public keys',
                                            key_req_stmts)
        stmts += Comment.comment_wrap_block(
            'Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(
            IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # Forward the circuit IO arrays and the start offsets of the callee's segments.
            # NOTE(review): presumably FunctionCallExpr keeps a reference to `args`, so extending it in place
            # extends the call's arguments - confirm.
            args += [
                in_arr_var.clone(),
                NumberLiteralExpr(ext_circuit.in_size),
                IdentifierExpr(cfg.zk_out_name),
                NumberLiteralExpr(ext_circuit.out_size)
            ]

        if int_fct.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [
                VariableDeclarationStatement(deep_copy(vd))
                for vd in int_fct.return_var_decls
            ])
            in_call = TupleExpr([
                IdentifierExpr(vd.idf.clone())
                for vd in int_fct.return_var_decls
            ]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof:
            verifier = IdentifierExpr(
                cfg.get_contract_var_name(
                    ext_circuit.verifier_contract_type.code()))
            verifier_args = [
                IdentifierExpr(cfg.proof_param_name),
                IdentifierExpr(cfg.zk_in_name),
                IdentifierExpr(cfg.zk_out_name)
            ]
            verify = ExpressionStatement(
                verifier.call(cfg.verification_function_name, verifier_args))
            stmts.append(
                StatementList(
                    [Comment('Verify zk proof of execution'), verify],
                    excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(
                ReturnStatement(
                    TupleExpr([
                        IdentifierExpr(vd.idf.clone())
                        for vd in int_fct.return_var_decls
                    ])))

        return Block(stmts)
예제 #12
0
    def evaluate_stmt_in_circuit(self, ast: Statement) -> AssignmentStatement:
        """
        Evaluate an entire statement privately.

        This works by turning the statement into an assignment statement where the

        * lhs is a tuple of all external locations (defined outside statement), which are modified inside the statement
        * rhs is the return value of an inlined function call expression to a virtual function where body = the statement + return statement \
          which returns a tuple of the most recent SSA version of all modified locations

        If none of the modified values is in scope at the statement, the returned statement is a placeholder
        ExpressionStatement wrapping the literal 0 instead.

        Note: Modifying external locations which are not owned by @me inside the statement is illegal (would leak information).
        Note: At the moment, this is only used for if statements with a private condition.

        :param ast: the statement to evaluate inside the circuit
        :return: AssignmentStatement as described above (or the ExpressionStatement placeholder)
        """
        # Placeholder result; replaced by an (initially empty) AssignmentStatement as soon as one
        # modified value is in scope at ast, i.e. is an external location
        astmt = ExpressionStatement(NumberLiteralExpr(0))
        for var in ast.modified_values:
            if var.in_scope_at(ast):
                astmt = AssignmentStatement(None, None)
                break

        # The replacement statement reuses the analysis state in front of the original statement
        astmt.before_analysis = ast.before_analysis

        # External values written inside statement -> function return values
        ret_params = []
        for var in ast.modified_values:
            if var.in_scope_at(ast):
                # side effect affects location outside statement and has privacy @me
                assert ast.before_analysis.same_partition(
                    var.privacy, Expression.me_expr())
                assert isinstance(
                    var.target,
                    (Parameter, VariableDeclaration, StateVariableDeclaration))
                t = var.target.annotated_type.zkay_type
                if not t.type_name.is_primitive_type():
                    raise NotImplementedError(
                        'Reference types inside private if statements are not supported'
                    )
                ret_t = AnnotatedTypeName(t.type_name, Expression.me_expr(),
                                          t.homomorphism)  # t, but @me
                ret_param = IdentifierExpr(var.target.idf.clone(),
                                           ret_t).override(target=var.target)
                # Attribute the return-value expression to the replacement statement
                ret_param.statement = astmt
                ret_params.append(ret_param)

        # Build the imaginary function
        # NOTE(review): positional arguments presumably are (idf, parameters, modifiers, return parameters, body)
        fdef = ConstructorOrFunctionDefinition(
            Identifier('<stmt_fct>'), [], ['private'], [
                Parameter([], ret.annotated_type, ret.target.idf)
                for ret in ret_params
            ], Block([ast, ReturnStatement(TupleExpr(ret_params))]))
        # The inliner visits a deep copy of original_body, so it must be set on the virtual function too
        fdef.original_body = fdef.body
        fdef.body.parent = fdef
        fdef.parent = ast

        # inline "Call" to the imaginary function
        fcall = FunctionCallExpr(
            IdentifierExpr('<stmt_fct>').override(target=fdef), [])
        fcall.statement = astmt
        ret_args = self.inline_function_call_into_circuit(fcall)

        # Move all return values out of the circuit
        # (a single return value comes back unwrapped -> normalize to a TupleExpr)
        if not isinstance(ret_args, TupleExpr):
            ret_args = TupleExpr([ret_args])
        for ret_arg in ret_args.elements:
            ret_arg.statement = astmt
        ret_arg_outs = [
            self._get_circuit_output_for_private_expression(
                ret_arg, Expression.me_expr(),
                ret_param.annotated_type.homomorphism)
            for ret_param, ret_arg in zip(ret_params, ret_args.elements)
        ]

        # Create assignment statement
        if ret_params:
            astmt.lhs = TupleExpr(
                [ret_param.clone() for ret_param in ret_params])
            astmt.rhs = TupleExpr(ret_arg_outs)
            return astmt
        else:
            # No external location modified -> astmt is still the placeholder
            assert isinstance(astmt, ExpressionStatement)
            return astmt
예제 #13
0
 def get_randomness_for_rerand(self, expr: Expression) -> IdentifierExpr:
     """
     Create a fresh secret circuit input of randomness type for expr's crypto backend.

     :param expr: expression whose annotated type name provides the crypto params
     :return: IdentifierExpr referencing the newly created secret input idf
     """
     crypto_params = expr.annotated_type.type_name.crypto_params
     rnd_idf = self._secret_input_name_factory.get_new_idf(
         TypeName.rnd_type(crypto_params))
     return IdentifierExpr(rnd_idf)
예제 #14
0
파일: name_remapper.py 프로젝트: nibau/zkay
    def join_branch(self, stmt, true_cond_for_other_branch: IdentifierExpr,
                    other_branch_state: Any, create_val_for_name_and_expr_fct: Callable[[K, Expression], V]):
        """
        Perform an SSA join for two branches.

        | i.e. if key is not remapped in any branch -> keep previous remapping
        |      if key is altered in at least one branch -> remap to conditional assignment of latest remapped version in either branch

        The current remap state (self.rmap) is taken as the state at the end of the *false* branch and is
        replaced by the joined state. If a key is remapped in only one branch and that remapping is a circuit
        input (PUB_CIRCUIT_ARG / PRIV_CIRCUIT_VAL, i.e. the value was only read), it is kept as-is.

        :param stmt: the branch statement, variables which are not already in scope at that statement are not included in the joined state
        :param true_cond_for_other_branch: IdentifierExpression which evaluates to true at runtime if other_branch is taken
        :param other_branch_state: remap state at the end of other branch (obtained using get_state)
        :param create_val_for_name_and_expr_fct: function to introduce a new temporary variable to which the given expression is assigned

        :Example use:

        ::

            with remapper.remap_scope(persist_globals=False):
                <process true branch>
                true_state = remapper.get_state()
            if <has false branch>:
                <process false branch>
            remapper.join_branch(stmt, cond_idf_expr, true_state, <create_tmp_var(idf, expr) function>)
        """
        true_state = other_branch_state
        false_state = self.rmap
        self.rmap = {}

        def join(key, val, then_expr, else_expr):
            """Return new temporary HybridArgumentIdf with value cond ? then_expr : else_expr, typed like val."""
            rhs = FunctionCallExpr(BuiltinFunction('ite'), [true_cond_for_other_branch.clone(), then_expr, else_expr]).as_type(val.t)
            return create_val_for_name_and_expr_fct(key.name, rhs)

        def is_circuit_input(val) -> bool:
            """Return True if val is a circuit input idf, i.e. the key was only read in its branch."""
            return isinstance(val, HybridArgumentIdf) and (val.arg_type == HybridArgType.PUB_CIRCUIT_ARG or val.arg_type == HybridArgType.PRIV_CIRCUIT_VAL)

        def previous_value(key):
            """Return an expression reading key's value as it was before the branch statement."""
            key_decl = key.parent
            assert key_decl.annotated_type is not None
            prev_val = IdentifierExpr(key.clone()).as_type(key_decl.annotated_type.zkay_type.clone())
            return prev_val.override(target=key_decl, parent=stmt, statement=stmt)

        for key, val in true_state.items():
            if not SymbolTableLinker.in_scope_at(key, stmt):
                # Don't keep local values
                continue

            if key in false_state and false_state[key].name == val.name:
                # key was not modified in either branch -> simply keep
                assert false_state[key] == val
                self.rmap[key] = val
            elif key not in false_state:
                if is_circuit_input(val):
                    # value was only read (remapping points to a circuit input) -> can just take as-is
                    self.rmap[key] = val
                else:
                    # key was only modified in true branch
                    # remap key -> new temporary with value cond ? new_value : old_value
                    prev_val = previous_value(key)
                    self.rmap[key] = join(key, val, true_state[key].get_idf_expr(stmt), prev_val)
            else:
                # key was modified in both branches
                # remap key -> new temporary with value cond ? true_val : false_val
                self.rmap[key] = join(key, val, true_state[key].get_idf_expr(stmt), false_state[key].get_idf_expr(stmt))

        for key, val in false_state.items():
            if not SymbolTableLinker.in_scope_at(key, stmt):
                # Don't keep local values
                continue

            if key in true_state:
                # Already joined in the loop above
                continue

            if is_circuit_input(val):
                # value was only read -> can just take as-is
                self.rmap[key] = val
            else:
                # key was only modified in false branch
                # remap key -> new temporary with value cond ? old_value : new_value
                prev_val = previous_value(key)
                self.rmap[key] = join(key, val, prev_val, false_state[key].get_idf_expr(stmt))
예제 #15
0
파일: symbol_table.py 프로젝트: nibau/zkay
 def visitIdentifierExpr(self, ast: IdentifierExpr):
     """Link ast to the declaration it refers to; a declaration must be found."""
     ast.target = self.find_identifier_declaration(ast)
     assert ast.target is not None
예제 #16
0
    def create_external_wrapper_body(int_fct: ConstructorOrFunctionDefinition, ext_circuit: CircuitHelper,
                                     original_params: List[Parameter], requires_proof: bool) -> Block:
        """
        Return Block with external wrapper function body.

        The statements of the returned block appear in this order:

        1. private-key requests for every symmetric cipher backend for which msg.sender's key is needed
        2. a require statement checking that the out array has exactly ``ext_circuit.out_size_trans`` elements
        3. the declaration of the in array (``ext_circuit.in_size_trans`` elements)
        4. the requested global public keys, copied into the in array starting at index 0
        5. the cipher-typed parameters, copied into the in array right after the keys; symmetrically
           encrypted parameters are additionally copied to memory with their sender field set
        6. the call to the internal function (the in/out arrays and start indices are forwarded when
           ``int_fct.requires_verification`` is set)
        7. the call to the verifier contract (only if requested, see ``requires_proof``)
        8. a return statement forwarding the internal function's return values, if it has any

        :param int_fct: corresponding internal function
        :param ext_circuit: [SIDE EFFECT] circuit helper of the external wrapper function; key requests,
                            parameter encryption checks and the internal call are registered on it
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param requires_proof: if True and ``cfg.disable_verification`` is not set, a call to
                               ``cfg.verification_function_name`` on the verifier contract is appended
        :return: body with wrapper code
        """
        # Cipher-typed parameters, and the crypto backends they use (deduplicated, first-seen order kept)
        priv_args = [p for p in original_params if p.annotated_type.is_cipher()]
        args_backends = OrderedDict.fromkeys([p.annotated_type.type_name.crypto_params for p in priv_args])
        stmts = []

        for crypto_params in args_backends:
            assert crypto_params in int_fct.used_crypto_backends
            # If there are any private arguments with homomorphism 'hom', we need the public key for that crypto backend
            ext_circuit._require_public_key_for_label_at(None, Expression.me_expr(), crypto_params)
        for crypto_params in cfg.all_crypto_params():
            if crypto_params.is_symmetric_cipher():
                if (MeExpr(), crypto_params) in ext_circuit.requested_global_keys or crypto_params in args_backends:
                    # Make sure msg.sender's key pair is available in the circuit
                    stmts += ext_circuit.request_private_key(crypto_params)

        # Verify that out parameter has correct size
        stmts += [RequireStatement(IdentifierExpr(cfg.zk_out_name).dot('length').binop('==', NumberLiteralExpr(ext_circuit.out_size_trans)))]

        # IdentifierExpr for array var holding serialized public circuit inputs
        in_arr_var = IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))

        # Request static public keys
        # `offset` is the next free index in the in array; the parameter loop below continues from it.
        offset = 0
        key_req_stmts = []
        # In-array index at which msg.sender's public key is stored, per crypto backend
        me_key_idx: Dict[CryptoParams, int] = {}
        if ext_circuit.requested_global_keys:
            # Ensure that me public key is stored starting at in[0]
            # NOTE(review): no reordering happens here; keys are copied in the order of
            # requested_global_keys, so this presumably relies on the me key being requested first -- confirm.
            keys = [key for key in ext_circuit.requested_global_keys]

            # Declare one temporary key variable per crypto backend used by the internal function
            tmp_keys = {}
            for crypto_params in int_fct.used_crypto_backends:
                tmp_key_var = Identifier(f'_tmp_key_{crypto_params.identifier_name}')
                key_req_stmts.append(tmp_key_var.decl_var(AnnotatedTypeName.key_type(crypto_params)))
                tmp_keys[crypto_params] = tmp_key_var
            for (key_owner, crypto_params) in keys:
                tmp_key_var = tmp_keys[crypto_params]
                idf, assignment = ext_circuit.request_public_key(crypto_params, key_owner, ext_circuit.get_glob_key_name(key_owner, crypto_params))
                # Redirect the key assignment into the temporary key variable of this backend
                assignment.lhs = IdentifierExpr(tmp_key_var.clone())
                key_req_stmts.append(assignment)

                # Remember me-keys for later use in symmetrically encrypted keys
                if key_owner == MeExpr():
                    assert crypto_params not in me_key_idx
                    me_key_idx[crypto_params] = offset

                # Manually add to circuit inputs
                key_len = crypto_params.key_len
                key_req_stmts.append(in_arr_var.slice(offset, key_len).assign(IdentifierExpr(tmp_key_var.clone()).slice(0, key_len)))
                offset += key_len
                # Checked after every key: the manually tracked offset must equal the circuit's input size
                # (presumably request_public_key registers the key as a circuit input -- confirm)
                assert offset == ext_circuit.in_size

        # Check encrypted parameters
        param_stmts = []
        for p in original_params:
            """ * of T_e rule 8 """
            if p.annotated_type.is_cipher():
                cipher_payload_len = p.annotated_type.type_name.crypto_params.cipher_payload_len
                # in[offset : offset + cipher_payload_len] = p[0 : cipher_payload_len]
                assign_stmt = in_arr_var.slice(offset, cipher_payload_len).assign(IdentifierExpr(p.idf.clone()).slice(0, cipher_payload_len))
                ext_circuit.ensure_parameter_encryption(assign_stmt, p)

                # Manually add to circuit inputs
                param_stmts.append(assign_stmt)
                offset += cipher_payload_len

        # Populate sender field of parameters encrypted with a symmetric cipher
        copy_stmts = []
        for p in original_params:
            if p.annotated_type.is_cipher():
                c = p.annotated_type.type_name
                assert isinstance(c, CipherText)
                if c.crypto_params.is_symmetric_cipher():
                    # Sender field value: the in-array element at msg.sender's key index for this backend
                    sender_key = in_arr_var.index(me_key_idx[c.crypto_params])
                    idf = IdentifierExpr(p.idf.clone()).as_type(p.annotated_type.clone())
                    # NOTE(review): looked up via the homomorphism here, while the loop above uses
                    # type_name.crypto_params directly; presumably both give the same length -- confirm.
                    cipher_payload_len = cfg.get_crypto_params(p.annotated_type.homomorphism).cipher_payload_len
                    # Payload elements of the parameter, followed by the sender field
                    lit = ArrayLiteralExpr([idf.clone().index(i) for i in range(cipher_payload_len)] + [sender_key])
                    # Declares a 'memory' variable with the parameter's own name, initialised with `lit`
                    copy_stmts.append(VariableDeclarationStatement(VariableDeclaration([], p.annotated_type.clone(), p.idf.clone(), 'memory'), lit))
        if copy_stmts:
            param_stmts += [Comment(), Comment('Copy from calldata to memory and set sender field')] + copy_stmts

        # Declare in array
        new_in_array_expr = NewExpr(AnnotatedTypeName(TypeName.dyn_uint_array()), [NumberLiteralExpr(ext_circuit.in_size_trans)])
        in_var_decl = in_arr_var.idf.decl_var(TypeName.dyn_uint_array(), new_in_array_expr)
        stmts.append(in_var_decl)
        stmts.append(Comment())
        stmts += Comment.comment_wrap_block('Request static public keys', key_req_stmts)
        stmts += Comment.comment_wrap_block('Backup private arguments for verification', param_stmts)

        # Call internal function
        args = [IdentifierExpr(param.idf.clone()) for param in original_params]
        internal_call = FunctionCallExpr(IdentifierExpr(int_fct.idf.clone()).override(target=int_fct), args)
        # NOTE(review): the meaning of sec_start_offset is defined by FunctionCallExpr/CircuitHelper (not visible here)
        internal_call.sec_start_offset = ext_circuit.priv_in_size

        if int_fct.requires_verification:
            ext_circuit.call_function(internal_call)
            # Forward the in array, the start index ext_circuit.in_size, the out array and the start index
            # ext_circuit.out_size to the internal function.
            # NOTE(review): `args` is extended after FunctionCallExpr was built from it (and after call_function
            # saw the call); this only reaches the call if FunctionCallExpr keeps a reference to the list -- confirm.
            args += [in_arr_var.clone(), NumberLiteralExpr(ext_circuit.in_size),
                     IdentifierExpr(cfg.zk_out_name), NumberLiteralExpr(ext_circuit.out_size)]

        if int_fct.return_parameters:
            # Declare copies of the internal function's return variables and assign the call's result to them
            stmts += Comment.comment_list("Declare return variables", [VariableDeclarationStatement(deep_copy(vd)) for vd in int_fct.return_var_decls])
            in_call = TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls]).assign(internal_call)
        else:
            in_call = ExpressionStatement(internal_call)
        stmts.append(Comment("Call internal function"))
        stmts.append(in_call)
        stmts.append(Comment())

        # Call verifier
        if requires_proof and not cfg.disable_verification:
            verifier = IdentifierExpr(cfg.get_contract_var_name(ext_circuit.verifier_contract_type.code()))
            verifier_args = [IdentifierExpr(cfg.proof_param_name), IdentifierExpr(cfg.zk_in_name), IdentifierExpr(cfg.zk_out_name)]
            verify = ExpressionStatement(verifier.call(cfg.verification_function_name, verifier_args))
            # The verifier call is wrapped in a StatementList marked excluded_from_simulation
            stmts.append(StatementList([Comment('Verify zk proof of execution'), verify], excluded_from_simulation=True))

        # Add return statement at the end if necessary
        if int_fct.return_parameters:
            stmts.append(ReturnStatement(TupleExpr([IdentifierExpr(vd.idf.clone()) for vd in int_fct.return_var_decls])))

        return Block(stmts)
예제 #17
0
    def create_internal_verification_wrapper(self, ast: ConstructorOrFunctionDefinition):
        """
        Equip an internal function with the parameters and boilerplate needed for proof verification.

        The function receives four extra parameters: the in array, its start index, the out array and
        its start index. Its body is rebuilt as:

        * require statements checking that its segments of both arrays fit inside the arrays,
        * the zk_data struct variable (if the circuit needs one) and the return variables (if any),
        * deserialization of circuit outputs from the out array (sender fields of symmetrically
          encrypted outputs are filled from msg.sender's key in the in array),
        * the original, already transformed body statements,
        * serialization of circuit inputs into the in array,
        * a return of the return variables when the circuit has a return variable.

        :param ast: [SIDE EFFECT] Internal function which requires verification; its modifiers,
                    parameter list and body statements are modified in place
        """
        circuit = self.circuits[ast]

        # Symmetric trafo requires msg.sender access -> change from pure to view
        symmetric_flags = [backend.is_symmetric_cipher() for backend in ast.used_crypto_backends]
        if 'pure' in ast.modifiers and any(symmetric_flags):
            ast.modifiers = [('view' if modifier == 'pure' else modifier) for modifier in ast.modifiers]

        # Append the in/out arrays and their start indices as additional parameters
        extra_params = (
            (Array(AnnotatedTypeName.uint_all()), cfg.zk_in_name),
            (AnnotatedTypeName.uint_all(), f'{cfg.zk_in_name}_start_idx'),
            (Array(AnnotatedTypeName.uint_all()), cfg.zk_out_name),
            (AnnotatedTypeName.uint_all(), f'{cfg.zk_out_name}_start_idx'),
        )
        for param_type, param_name in extra_params:
            ast.add_param(param_type, param_name)

        # These expression nodes are reused by the bounds checks and the (de)serialization code below
        out_start_idx = IdentifierExpr(f'{cfg.zk_out_name}_start_idx')
        in_start_idx = IdentifierExpr(f'{cfg.zk_in_name}_start_idx')
        out_var = IdentifierExpr(cfg.zk_out_name)
        in_var = IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))

        # start_idx + size <= array.length for both arrays
        new_body = [
            RequireStatement(out_start_idx.binop('+', NumberLiteralExpr(circuit.out_size_trans)).binop('<=', out_var.dot('length'))),
            RequireStatement(in_start_idx.binop('+', NumberLiteralExpr(circuit.in_size_trans)).binop('<=', in_var.dot('length'))),
        ]

        if circuit.requires_zk_data_struct():
            zk_struct_type = StructTypeName([Identifier(circuit.zk_data_struct_name)])
            new_body.append(Identifier(cfg.zk_data_var_name).decl_var(zk_struct_type))
            new_body.append(BlankLine())

        if ast.return_parameters:
            return_decls = [VariableDeclarationStatement(vd) for vd in ast.return_var_decls]
            new_body.extend(Comment.comment_list("Declare return variables", return_decls))

        # Position of msg.sender's public key inside the in array, per crypto backend
        me_key_idx: Dict[CryptoParams, int] = {}
        key_pos = 0
        for key_owner, crypto_params in circuit.requested_global_keys:
            if key_owner == MeExpr():
                assert crypto_params not in me_key_idx
                me_key_idx[crypto_params] = key_pos
            key_pos += crypto_params.key_len

        # Read circuit outputs from the out array
        deserialize_stmts = []
        out_pos = 0
        for out_idf in circuit.output_idfs:
            deserialize_stmts.append(out_idf.deserialize(cfg.zk_out_name, out_start_idx, out_pos))
            if isinstance(out_idf.t, CipherText) and out_idf.t.crypto_params.is_symmetric_cipher():
                # Set the sender field of user-encrypted values to msg.sender's key
                # (assumes key_len == 1 for every symmetric cipher)
                cipher_params = out_idf.t.crypto_params
                assert cipher_params in me_key_idx, "Symmetric cipher but did not request me key"
                sender_key = in_var.index(me_key_idx[cipher_params])
                sender_field = out_idf.get_loc_expr().index(cipher_params.cipher_payload_len)
                deserialize_stmts.append(sender_field.assign(sender_key))
            out_pos += out_idf.t.size_in_uints
        if deserialize_stmts:
            new_body.append(StatementList(Comment.comment_wrap_block("Deserialize output values", deserialize_stmts),
                                          excluded_from_simulation=True))

        # The original transformed function body goes in the middle
        new_body.extend(ast.body.statements)

        # Write circuit inputs into the in array
        serialize_stmts = []
        in_pos = 0
        for in_idf in circuit.input_idfs:
            serialize_stmts.append(in_idf.serialize(cfg.zk_in_name, in_start_idx, in_pos))
            in_pos += in_idf.t.size_in_uints
        if in_pos:
            new_body.append(Comment())
            new_body.extend(Comment.comment_wrap_block('Serialize input values', serialize_stmts))

        # Return statements were previously replaced by assignments to the return variables
        # (by ZkayStatementTransformer), so a single final return is added here
        if circuit.has_return_var:
            return_values = TupleExpr([IdentifierExpr(vd.idf.clone()).override(target=vd) for vd in ast.return_var_decls])
            new_body.append(ReturnStatement(return_values))

        # Replace the statements in place, keeping the same list object
        ast.body.statements[:] = new_body