Ejemplo n.º 1
0
 def visitConstructorOrFunctionDefinition(
         self, ast: ConstructorOrFunctionDefinition):
     """Propagate the verification requirement from callees to ``ast``.

     If ``ast`` does not require verification yet but at least one entry
     of ``ast.called_functions`` does, ``ast.requires_verification`` is
     set. In that case ``ast.requires_verification_when_external`` is
     set as well, provided ``ast.can_be_external`` is truthy.

     :param ast: [SIDE EFFECT] function definition to update
     """
     # Already flagged -> nothing to propagate
     if ast.requires_verification:
         return

     # Stops at the first callee which requires verification
     if not any(callee.requires_verification
                for callee in ast.called_functions):
         return

     ast.requires_verification = True
     if ast.can_be_external:
         ast.requires_verification_when_external = True
Ejemplo n.º 2
0
    def visitConstructorOrFunctionDefinition(
            self, ast: ConstructorOrFunctionDefinition):
        """Decide whether ``ast`` needs verification when called externally.

        The body is visited first. Afterwards, a function which can be
        external gets ``requires_verification_when_external`` set if it
        requires verification itself or if at least one of its parameters
        has a private annotated type.

        :param ast: [SIDE EFFECT] function definition to update
        """
        self.visit(ast.body)

        # Purely internal functions are never flagged here
        if not ast.can_be_external:
            return

        # Parameters are only inspected when the function itself does not
        # already require verification (short-circuit `or`)
        if ast.requires_verification or any(
                param.annotated_type.is_private()
                for param in ast.parameters):
            ast.requires_verification_when_external = True
Ejemplo n.º 3
0
 def visitConstructorOrFunctionDefinition(
         self, ast: ConstructorOrFunctionDefinition):
     """Clear ``ast.can_be_private`` if any callee cannot be private.

     Functions whose flag is already falsy are left untouched.

     :param ast: [SIDE EFFECT] function definition to update
     """
     if not ast.can_be_private:
         return

     # Stops at the first callee which cannot be private
     if any(not callee.can_be_private for callee in ast.called_functions):
         ast.can_be_private = False
Ejemplo n.º 4
0
    def visitConstructorOrFunctionDefinition(
            self, ast: ConstructorOrFunctionDefinition):
        """Extend ``ast.called_functions`` to its transitive closure.

        Starting from the current entries, the callees of each newly
        discovered function are added (in discovery order) until a round
        discovers nothing new. If ``ast`` ends up among its own callees,
        it is marked recursive and its ``has_static_body`` flag is cleared.

        :param ast: [SIDE EFFECT] function definition to update
        """
        reachable = ast.called_functions

        # Frontier-based fixed point iteration: only the functions found in
        # the previous round are expanded in the next one.
        frontier = reachable
        while frontier:
            discovered = {}
            for caller in frontier:
                for callee in caller.called_functions:
                    if callee not in reachable:
                        discovered[callee] = None
            # Merge only after the round is complete, so the containers being
            # iterated are never modified during iteration
            reachable.update(discovered)
            frontier = discovered

        if ast in reachable:
            ast.is_recursive = True
            ast.has_static_body = False
Ejemplo n.º 5
0
    def visitConstructorOrFunctionDefinition(
            self, ast: ConstructorOrFunctionDefinition):
        """Clear ``ast.has_static_body`` if any callee lacks a static body.

        Functions whose flag is already falsy are left untouched.

        :param ast: [SIDE EFFECT] function definition to update
        """
        # A callee without static body means that this function (directly or
        # indirectly) calls a recursive function
        if ast.has_static_body and not all(
                callee.has_static_body for callee in ast.called_functions):
            ast.has_static_body = False
Ejemplo n.º 6
0
class GlobalDefs:
    """
    Class-level struct definitions named '<address>', '<address_payable>', '<msg>', '<block>' and '<tx>'.

    Every definition is created once when the class body is executed and is passed to ``set_parents``
    right after its construction.

    NOTE(review): the names and members look like the built-in Solidity globals and address members - confirm.
    """

    # Disabled definition of a `gasleft` function, kept as a comment:
    # gasleft: FunctionDefinition = FunctionDefinition(
    #     idf=Identifier('gasleft'),
    #     parameters=[],
    #     modifiers=[],
    #     return_parameters=[Parameter([], annotated_type=AnnotatedTypeName.uint_all(), idf=Identifier(''))],
    #     body=Block([])
    # )
    # gasleft.idf.parent = gasleft

    # '<address>': a single `balance` member of type AnnotatedTypeName.uint_all()
    address_struct: StructDefinition = StructDefinition(
        Identifier('<address>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance'))
        ])
    # presumably initializes the parent references of all contained nodes - verify against set_parents
    set_parents(address_struct)

    # '<address_payable>': `balance` plus two 'public' function members with an empty body:
    #   * send: one unnamed uint_all parameter, one unnamed bool_all return parameter
    #   * transfer: one unnamed uint_all parameter, no return parameters
    address_payable_struct: StructDefinition = StructDefinition(
        Identifier('<address_payable>'), [
            VariableDeclaration([], AnnotatedTypeName.uint_all(),
                                Identifier('balance')),
            ConstructorOrFunctionDefinition(
                Identifier('send'),
                [Parameter([], AnnotatedTypeName.uint_all(), Identifier(''))],
                ['public'],
                [Parameter([], AnnotatedTypeName.bool_all(), Identifier(''))],
                Block([])),
            ConstructorOrFunctionDefinition(
                Identifier('transfer'),
                [Parameter([], AnnotatedTypeName.uint_all(), Identifier(''))],
                ['public'], [], Block([])),
        ])
    # Indices 1 and 2 are `send` and `transfer`, assuming `members` keeps the order of the list passed above.
    # NOTE(review): the reason why they must not be private is not visible here - presumably because they
    # transfer funds; confirm.
    address_payable_struct.members[1].can_be_private = False
    address_payable_struct.members[2].can_be_private = False
    set_parents(address_payable_struct)

    # '<msg>': `sender` (address payable type) and `value` (uint_all)
    msg_struct: StructDefinition = StructDefinition(Identifier('<msg>'), [
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('sender')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('value')),
    ])
    set_parents(msg_struct)

    # '<block>': `coinbase` (address payable type) and the uint_all members
    # `difficulty`, `gaslimit`, `number` and `timestamp`
    block_struct: StructDefinition = StructDefinition(Identifier('<block>'), [
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('coinbase')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('difficulty')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('gaslimit')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('number')),
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('timestamp')),
    ])
    set_parents(block_struct)

    # '<tx>': `gasprice` (uint_all) and `origin` (address payable type)
    tx_struct: StructDefinition = StructDefinition(Identifier('<tx>'), [
        VariableDeclaration([], AnnotatedTypeName.uint_all(),
                            Identifier('gasprice')),
        VariableDeclaration([],
                            AnnotatedTypeName(TypeName.address_payable_type()),
                            Identifier('origin')),
    ])
    set_parents(tx_struct)
Ejemplo n.º 7
0
 def visitConstructorOrFunctionDefinition(
         self, ast: ConstructorOrFunctionDefinition):
     """Set ``ast.namespace`` to the parent's namespace followed by ``ast.idf``.

     A function without parent gets a namespace which consists only of its
     own identifier. The parent's namespace list itself is not modified.

     :param ast: [SIDE EFFECT] function definition to update
     """
     if ast.parent is None:
         enclosing = []
     else:
         enclosing = ast.parent.namespace
     # `+` builds a new list, the enclosing namespace stays untouched
     ast.namespace = enclosing + [ast.idf]
Ejemplo n.º 8
0
    def evaluate_stmt_in_circuit(self, ast: Statement) -> AssignmentStatement:
        """
        Evaluate a complete statement privately (inside the circuit).

        The statement is wrapped into a virtual function '<stmt_fct>' whose body consists of the statement,
        followed by a return statement which returns all locations that are modified by the statement and are
        in scope at the statement. A call to this virtual function is inlined into the circuit and the resulting
        values are moved out of the circuit again.

        * If there is at least one such location, the result is an assignment statement with
          lhs = tuple of the modified locations and rhs = tuple of the corresponding circuit outputs.
        * Otherwise the result is an expression statement with the number literal 0.

        Note: Every modified location which is in scope at the statement must be owned by @me (asserted).
        Note: Only primitive types are supported for such locations.

        :param ast: the statement to evaluate inside the circuit
        :return: the replacement statement as described above
        :raise NotImplementedError: if a modified location in scope at the statement has a non-primitive type
        """
        # Dummy statement by default, real assignment as soon as anything visible outside is written
        astmt = ExpressionStatement(NumberLiteralExpr(0))
        if any(mod.in_scope_at(ast) for mod in ast.modified_values):
            astmt = AssignmentStatement(None, None)

        astmt.before_analysis = ast.before_analysis

        # Every external location written by the statement becomes a return value of the virtual function
        ret_params = []
        for mod in ast.modified_values:
            if not mod.in_scope_at(ast):
                continue

            # The side effect is visible outside of the statement -> it must have privacy @me
            assert ast.before_analysis.same_partition(
                mod.privacy, Expression.me_expr())
            target = mod.target
            assert isinstance(
                target,
                (Parameter, VariableDeclaration, StateVariableDeclaration))

            zkay_t = target.annotated_type.zkay_type
            if not zkay_t.type_name.is_primitive_type():
                raise NotImplementedError(
                    'Reference types inside private if statements are not supported'
                )

            # Same type as the target, but owned by @me
            me_type = AnnotatedTypeName(zkay_t.type_name,
                                        Expression.me_expr(),
                                        zkay_t.homomorphism)
            ret_expr = IdentifierExpr(target.idf.clone(),
                                      me_type).override(target=target)
            ret_expr.statement = astmt
            ret_params.append(ret_expr)

        # Virtual function: body = statement + return of all modified external locations
        fct_name = Identifier('<stmt_fct>')
        fct_returns = [
            Parameter([], ret.annotated_type, ret.target.idf)
            for ret in ret_params
        ]
        fct_body = Block([ast, ReturnStatement(TupleExpr(ret_params))])
        fdef = ConstructorOrFunctionDefinition(fct_name, [], ['private'],
                                               fct_returns, fct_body)
        fdef.original_body = fdef.body
        fdef.body.parent = fdef
        fdef.parent = ast

        # Inline a "call" to the virtual function into the circuit
        callee_ref = IdentifierExpr('<stmt_fct>').override(target=fdef)
        fcall = FunctionCallExpr(callee_ref, [])
        fcall.statement = astmt
        inlined = self.inline_function_call_into_circuit(fcall)

        # Normalize the result to a tuple and move every element out of the circuit
        ret_tuple = inlined if isinstance(inlined, TupleExpr) else TupleExpr([inlined])
        for elem in ret_tuple.elements:
            elem.statement = astmt

        circuit_outs = []
        for ret_expr, elem in zip(ret_params, ret_tuple.elements):
            circuit_outs.append(
                self._get_circuit_output_for_private_expression(
                    elem, Expression.me_expr(),
                    ret_expr.annotated_type.homomorphism))

        # Nothing visible outside was written -> the dummy statement is returned as is
        if not ret_params:
            assert isinstance(astmt, ExpressionStatement)
            return astmt

        # Assign the circuit outputs back to the modified locations
        astmt.lhs = TupleExpr([ret_expr.clone() for ret_expr in ret_params])
        astmt.rhs = TupleExpr(circuit_outs)
        return astmt
Ejemplo n.º 9
0
    def split_into_external_and_internal_fct(self, f: ConstructorOrFunctionDefinition, original_params: List[Parameter],
                                             global_owners: List[PrivacyLabelExpr]) -> Tuple[ConstructorOrFunctionDefinition,
                                                                                             ConstructorOrFunctionDefinition]:
        """
        Take public function f and split it into an internal function and an external wrapper function.

        The wrapper is a newly created definition which takes over the identifier and the return parameters of f.
        f itself is renamed to ``cfg.get_internal_name(f)``, its 'public' modifier is replaced by 'internal' and its
        'payable' modifier is dropped. A new circuit is created for the wrapper and registered in ``self.circuits``;
        the circuit of f is removed from ``self.circuits`` if f does not require verification.

        :param f: [SIDE EFFECT] function to split (requires_verification_when_external must be set)
        :param original_params: list of transformed function parameters without additional parameters added due to transformation
        :param global_owners: list of static labels (me + final address state variable identifiers)
        :return: Tuple of newly created external and internal function definitions (the internal one is f itself)
        """
        assert f.requires_verification_when_external

        # Create new empty function with same parameters as original -> external wrapper
        if f.is_function:
            # Wrapper of a function is 'external' and works on copies of the parameters
            # with the storage location changed via with_changed_storage('memory', 'calldata')
            new_modifiers = ['external']
            original_params = [deep_copy(p, with_types=True).with_changed_storage('memory', 'calldata') for p in original_params]
        else:
            # presumably the constructor case (TODO confirm): stays 'public' and reuses the given parameter objects
            new_modifiers = ['public']
        if f.is_payable:
            new_modifiers.append('payable')

        # A function without side effects gets a 'view' wrapper which takes no proof parameter
        requires_proof = True
        if not f.has_side_effects:
            requires_proof = False
            new_modifiers.append('view')
        # Uses the identifier of f before f is renamed below
        new_f = ConstructorOrFunctionDefinition(f.idf, original_params, new_modifiers, f.return_parameters, Block([]))

        # Make original function internal
        f.idf = Identifier(cfg.get_internal_name(f))
        f.modifiers = ['internal' if mod == 'public' else mod for mod in f.modifiers if mod != 'payable']
        f.requires_verification_when_external = False

        # Create new circuit for external function
        # (the circuit which is currently registered for f is handed to the helper)
        circuit = self.create_circuit_helper(new_f, global_owners, self.circuits[f])
        if not f.requires_verification:
            del self.circuits[f]
        self.circuits[new_f] = circuit

        # Set meta attributes and populate body
        new_f.requires_verification = True
        new_f.requires_verification_when_external = True
        # NOTE(review): unless the attribute assignment copies, new_f.called_functions is the same object as
        # f.called_functions, so the next line also adds f to its own called functions - confirm this is intended.
        new_f.called_functions = f.called_functions
        new_f.called_functions[f] = None
        new_f.used_crypto_backends = f.used_crypto_backends
        new_f.body = self.create_external_wrapper_body(f, circuit, original_params, requires_proof)

        # Add out and proof parameter to external wrapper
        storage_loc = 'calldata' if new_f.is_function else 'memory'
        new_f.add_param(Array(AnnotatedTypeName.uint_all()), Identifier(cfg.zk_out_name), storage_loc)

        if requires_proof:
            new_f.add_param(AnnotatedTypeName.proof_type(), Identifier(cfg.proof_param_name), storage_loc)

        return new_f, f
Ejemplo n.º 10
0
    def create_internal_verification_wrapper(self, ast: ConstructorOrFunctionDefinition) -> None:
        """
        Add the necessary additional parameters and boilerplate code for verification support to the given function.

        The function gets four additional parameters (in array, in start index, out array, out start index) and its
        body is replaced in place by the following sequence of statements:

        1. size checks for the in and out arrays
        2. declaration of the zk data struct variable (only if the circuit requires it)
        3. declaration of the return variables (only if the function has return parameters)
        4. deserialization of the circuit outputs from the out array (only if there are any)
        5. the previous body statements
        6. serialization of the circuit inputs into the in array (only if the total input size is nonzero)
        7. a return statement for the return variables (only if the circuit has a return variable)

        :param ast: [SIDE EFFECT] Internal function which requires verification
        """
        circuit = self.circuits[ast]
        stmts = []

        symmetric_cipher_used = any([backend.is_symmetric_cipher() for backend in ast.used_crypto_backends])
        if symmetric_cipher_used and 'pure' in ast.modifiers:
            # Symmetric trafo requires msg.sender access -> change from pure to view
            ast.modifiers = ['view' if mod == 'pure' else mod for mod in ast.modifiers]

        # Add additional params
        ast.add_param(Array(AnnotatedTypeName.uint_all()), cfg.zk_in_name)
        ast.add_param(AnnotatedTypeName.uint_all(), f'{cfg.zk_in_name}_start_idx')
        ast.add_param(Array(AnnotatedTypeName.uint_all()), cfg.zk_out_name)
        ast.add_param(AnnotatedTypeName.uint_all(), f'{cfg.zk_out_name}_start_idx')

        # Verify that in/out parameters have correct size
        # (start index + size used by this circuit must not exceed the length of the respective array)
        out_start_idx, in_start_idx = IdentifierExpr(f'{cfg.zk_out_name}_start_idx'), IdentifierExpr(f'{cfg.zk_in_name}_start_idx')
        out_var, in_var = IdentifierExpr(cfg.zk_out_name), IdentifierExpr(cfg.zk_in_name).as_type(Array(AnnotatedTypeName.uint_all()))
        stmts.append(RequireStatement(out_start_idx.binop('+', NumberLiteralExpr(circuit.out_size_trans)).binop('<=', out_var.dot('length'))))
        stmts.append(RequireStatement(in_start_idx.binop('+', NumberLiteralExpr(circuit.in_size_trans)).binop('<=', in_var.dot('length'))))

        # Declare zk_data struct var (if needed)
        if circuit.requires_zk_data_struct():
            zk_struct_type = StructTypeName([Identifier(circuit.zk_data_struct_name)])
            stmts += [Identifier(cfg.zk_data_var_name).decl_var(zk_struct_type), BlankLine()]

        # Declare return variable if necessary
        if ast.return_parameters:
            stmts += Comment.comment_list("Declare return variables", [VariableDeclarationStatement(vd) for vd in ast.return_var_decls])

        # Find all me-keys in the in array
        # (maps crypto params -> running offset, where the offset advances by key_len for every requested key)
        me_key_idx: Dict[CryptoParams, int] = {}
        offset = 0
        for (key_owner, crypto_params) in circuit.requested_global_keys:
            if key_owner == MeExpr():
                assert crypto_params not in me_key_idx
                me_key_idx[crypto_params] = offset
            offset += crypto_params.key_len

        # Deserialize out array (if any)
        deserialize_stmts = []
        offset = 0
        for s in circuit.output_idfs:
            deserialize_stmts.append(s.deserialize(cfg.zk_out_name, out_start_idx, offset))
            if isinstance(s.t, CipherText) and s.t.crypto_params.is_symmetric_cipher():
                # Assign sender field to user-encrypted values if necessary
                # Assumption: s.t.crypto_params.key_len == 1 for all symmetric ciphers
                assert s.t.crypto_params in me_key_idx, "Symmetric cipher but did not request me key"
                key_idx = me_key_idx[s.t.crypto_params]
                # NOTE(review): key_idx is used as an absolute index into the in array, the in start index
                # is not added here - confirm that this is intended
                sender_key = in_var.index(key_idx)
                cipher_payload_len = s.t.crypto_params.cipher_payload_len
                # The element at index cipher_payload_len of the deserialized value is set to the sender key
                deserialize_stmts.append(s.get_loc_expr().index(cipher_payload_len).assign(sender_key))
            offset += s.t.size_in_uints
        if deserialize_stmts:
            stmts.append(StatementList(Comment.comment_wrap_block("Deserialize output values", deserialize_stmts), excluded_from_simulation=True))

        # Include original transformed function body
        stmts += ast.body.statements

        # Serialize in parameters to in array (if any)
        serialize_stmts = []
        offset = 0
        for s in circuit.input_idfs:
            serialize_stmts += [s.serialize(cfg.zk_in_name, in_start_idx, offset)]
            offset += s.t.size_in_uints
        # Unlike the deserialization above, this block is guarded by the accumulated size instead of the statement list
        if offset:
            stmts.append(Comment())
            stmts += Comment.comment_wrap_block('Serialize input values', serialize_stmts)

        # Add return statement at the end if necessary
        # (was previously replaced by assignment to return_var by ZkayStatementTransformer)
        if circuit.has_return_var:
            stmts.append(ReturnStatement(TupleExpr([IdentifierExpr(vd.idf.clone()).override(target=vd) for vd in ast.return_var_decls])))

        # Slice assignment replaces the content of the existing statement list object in place
        ast.body.statements[:] = stmts
Ejemplo n.º 11
0
 def visitConstructorOrFunctionDefinition(self, ast: ConstructorOrFunctionDefinition):
     """Set ``ast.names`` to a mapping from parameter name to parameter identifier.

     If several parameters share a name, the identifier of the last one wins.

     :param ast: [SIDE EFFECT] function definition to update
     """
     by_name = {}
     for param in ast.parameters:
         identifier = param.idf
         by_name[identifier.name] = identifier
     ast.names = by_name