def find_or_create_data_key(self):
        """Return the id of the data key named ``self.key_alt_name``, creating it if absent.

        Ensures the unique ``keyAltNames`` index exists on the key vault.
        Also assigns ``self.client_encryption`` to a ClientEncryption bound to
        ``self.key_vault_namespace``. The MongoClient opened here stays in use
        by that ClientEncryption.

        Returns:
            tuple: ``(uuid_data_key_id, base_64_data_key_id)``. These are the
            key id as ``uuid.UUID`` and the base64 text (UTF-8 decoded) of
            its 16 raw bytes.
        """
        codec_options = CodecOptions(uuid_representation=STANDARD)

        key_vault_client = MongoClient(self.connection_string)

        # Read the key vault with STANDARD UUID decoding. This way the ``_id``
        # of an existing key (BSON binary subtype 4) comes back as
        # ``uuid.UUID``. With the client's default uuidRepresentation it may
        # decode as ``bson.Binary``, which has no ``.bytes`` attribute.
        key_vault = key_vault_client[self.key_db].get_collection(
            self.key_coll, codec_options=codec_options)

        self.ensure_unique_index_on_key_vault(key_vault)

        data_key = key_vault.find_one({"keyAltNames": self.key_alt_name})

        self.client_encryption = ClientEncryption(
            self.kms_providers, self.key_vault_namespace, key_vault_client,
            codec_options)

        if data_key is None:
            # create_data_key returns the new key id as binary data;
            # wrap its bytes as a UUID.
            data_key = self.client_encryption.create_data_key(
                "local", key_alt_names=[self.key_alt_name])
            uuid_data_key_id = UUID(bytes=data_key)
        else:
            uuid_data_key_id = data_key["_id"]

        base_64_data_key_id = (base64.b64encode(
            uuid_data_key_id.bytes).decode("utf-8"))

        return uuid_data_key_id, base_64_data_key_id
    def test_bson_errors(self):
        """Explicit encryption reports unencodable values as BSONError."""
        encryption = ClientEncryption(
            KMS_PROVIDERS, 'admin.datakeys', client_context.client, OPTS)
        self.addCleanup(encryption.close)

        # A plain object() has no BSON representation, so encoding it fails
        # before any key lookup matters; a random key id is enough.
        random_key_id = Binary(uuid.uuid4().bytes, UUID_SUBTYPE)
        deterministic = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
        with self.assertRaises(BSONError):
            encryption.encrypt(object(), deterministic, key_id=random_key_id)
示例#3
0
    def create_data_encryption_key(self, kms_providers):
        """Create a new "local" data encryption key and store it in the key vault.

        Args:
            kms_providers: KMS provider configuration handed to ClientEncryption.

        Returns:
            The new key id exactly as returned by
            ``ClientEncryption.create_data_key``.
        """
        # The ClientEncryption is only needed for this single call. Use it as a
        # context manager so it is closed afterwards instead of being leaked.
        with ClientEncryption(
                # pass in the kms_providers variable from the previous step
                kms_providers,
                self.key_vault_namespace,
                self.client,
                CodecOptions(uuid_representation=STANDARD)) as client_encryption:
            data_key_id = client_encryption.create_data_key("local")
        uuid_data_key_id = UUID(bytes=data_key_id)
        print(f'data key created using KMS provider: {uuid_data_key_id} \n')
        return data_key_id
示例#4
0
    def find_or_create_data_key(self):
        """
        In the guide, https://docs.mongodb.com/drivers/security/client-side-field-level-encryption-guide/,
        we create the data key and then show that it is created by
        using a find_one query. Here, in implementation, we only create the key if
        it doesn't already exist, ensuring we only have one local data key.

        We also use the key_alt_names field and provide a key alt name to aid in
        finding the key in the clients.py script.

        Returns the new key id as returned by ``create_data_key`` when a key is
        created. Otherwise returns the raw 16 bytes of the existing key's UUID.
        """
        key_vault_client = MongoClient(self.connection_string)
        try:
            # Decode UUIDs with the STANDARD representation so that an existing
            # key's ``_id`` (binary subtype 4) is a ``uuid.UUID`` and ``.bytes``
            # below works regardless of the client's default uuidRepresentation.
            key_vault = key_vault_client[self.key_db].get_collection(
                self.key_coll,
                codec_options=CodecOptions(uuid_representation=STANDARD))

            self.ensure_unique_index_on_key_vault(key_vault)

            data_key = key_vault.find_one({"keyAltNames": self.key_alt_name})

            # create a key
            if data_key is None:
                with ClientEncryption(
                        self.kms_provider, self.key_vault_namespace,
                        key_vault_client, CodecOptions(
                            uuid_representation=STANDARD)) as client_encryption:

                    # create data key using KMS master key
                    # NOTE(review): this branch returns the driver's key id
                    # object, while the branch below returns plain bytes —
                    # confirm callers treat both alike.
                    return client_encryption.create_data_key(
                        self.kms_provider_name,
                        key_alt_names=[self.key_alt_name],
                        master_key=self.master_key)

            return data_key['_id'].bytes
        finally:
            # This client is local to the call and nothing returned refers to
            # it, so release its connections.
            key_vault_client.close()
示例#5
0
 def test_with_statement(self):
     """Leaving the ``with`` block closes the ClientEncryption."""
     closed_message = 'Cannot use closed ClientEncryption'
     with ClientEncryption(KMS_PROVIDERS, 'keyvault.datakeys',
                           client_context.client, OPTS) as encryption:
         pass  # entering and exiting is all that is under test
     # Any use after the block must be rejected.
     with self.assertRaisesRegex(InvalidOperation, closed_message):
         encryption.create_data_key('local')
    def test_validation(self):
        """decrypt() and encrypt() reject arguments of the wrong BSON type."""
        encryption = ClientEncryption(KMS_PROVIDERS, 'admin.datakeys',
                                      client_context.client, OPTS)
        self.addCleanup(encryption.close)

        # decrypt() accepts only Binary subtype 6 (encrypted payloads).
        decrypt_msg = ('value to decrypt must be a bson.binary.Binary '
                       'with subtype 6')
        for bad_ciphertext in ('str', Binary(b'123')):
            with self.assertRaisesRegex(TypeError, decrypt_msg):
                encryption.decrypt(bad_ciphertext)

        # key_id must be Binary subtype 4; a uuid.UUID or generic Binary is
        # refused.
        key_id_msg = 'key_id must be a bson.binary.Binary with subtype 4'
        deterministic = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
        for bad_key_id in (uuid.uuid4(), Binary(b'123')):
            with self.assertRaisesRegex(TypeError, key_id_msg):
                encryption.encrypt('str', deterministic, key_id=bad_key_id)
示例#7
0
def encrypt_dict(encryption_client: ClientEncryption, val, filter_list=None):
    """Encrypt fields of a dict, optionally restricted by ``filter_list``.

    Traversal is delegated to ``recursive_replace`` (defined elsewhere). Every
    value it passes to the callback is encrypted with
    AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, using the data key whose alt
    name is ``ENCRYPTION_KEY_NAME``. Deterministic encryption maps equal
    plaintexts to equal ciphertexts.

    Args:
        encryption_client: ClientEncryption used for explicit encryption.
        val: The dict whose fields are encrypted.
        filter_list: Forwarded unchanged to ``recursive_replace``. It presumably
            lists the keys to encrypt (None = no filter) — confirm against
            ``recursive_replace``.

    Returns:
        Whatever ``recursive_replace`` returns for ``val``.
    """
    def _encrypt_value(value):
        return encryption_client.encrypt(
            value,
            Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
            key_alt_name=ENCRYPTION_KEY_NAME,
        )

    return recursive_replace(val, _encrypt_value, filter_list)
示例#8
0
def create_json_schema_file(kms_providers, key_vault_namespace,
                            key_vault_client):
    """Create a new "local" data key and write a JSON schema that uses it.

    The schema marks ``encryptedField`` as a deterministically encrypted
    string under the new key. It is written as canonical extended JSON to
    ``jsonSchema.json`` in the current working directory.

    Args:
        kms_providers: KMS provider configuration for ClientEncryption.
        key_vault_namespace: "db.collection" of the key vault.
        key_vault_client: MongoClient used to read/write the key vault.
    """
    # The ClientEncryption is only needed to create the data key. Use it as
    # a context manager so it is closed afterwards rather than leaked.
    with ClientEncryption(
            kms_providers,
            key_vault_namespace,
            key_vault_client,
            # The CodecOptions class used for encrypting and decrypting.
            # This should be the same CodecOptions instance you have configured
            # on MongoClient, Database, or Collection. We will not be calling
            # encrypt() or decrypt() in this example so we can use any
            # CodecOptions.
            CodecOptions()) as client_encryption:
        # Create a new data key and json schema for the encryptedField.
        # https://dochub.mongodb.org/core/client-side-field-level-encryption-automatic-encryption-rules
        data_key_id = client_encryption.create_data_key(
            'local', key_alt_names=['pymongo_encryption_example_1'])

    schema = {
        "properties": {
            "encryptedField": {
                "encrypt": {
                    "keyId": [data_key_id],
                    "bsonType": "string",
                    "algorithm":
                    Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
                }
            }
        },
        "bsonType": "object"
    }
    # Use CANONICAL_JSON_OPTIONS so that other drivers and tools will be
    # able to parse the MongoDB extended JSON file.
    json_schema_string = json_util.dumps(
        schema, json_options=json_util.CANONICAL_JSON_OPTIONS)

    with open('jsonSchema.json', 'w') as file:
        file.write(json_schema_string)
示例#9
0
    def _test_external_key_vault(self, with_external_key_vault):
        """Auto and explicit encryption against an internal or external key vault.

        When ``with_external_key_vault`` is true, the key vault is read through
        a separately authenticated client. Both the auto-encrypted insert and
        the explicit encrypt must then fail with an EncryptionError caused by
        AuthenticationFailed (server error code 18). Otherwise both succeed.
        """
        self.client.db.coll.drop()
        vault = create_key_vault(self.client.keyvault.datakeys,
                                 json_data('corpus', 'corpus-key-local.json'),
                                 json_data('corpus', 'corpus-key-aws.json'))
        self.addCleanup(vault.drop)

        # Configure the encrypted field via the local schema_map option.
        schema_map = {'db.coll': json_data('external', 'external-schema.json')}
        if not with_external_key_vault:
            vault_client = client_context.client
        else:
            vault_client = rs_or_single_client(username='******',
                                               password='******')
            self.addCleanup(vault_client.close)
        auto_opts = AutoEncryptionOpts(self.kms_providers(),
                                       'keyvault.datakeys',
                                       schema_map=schema_map,
                                       key_vault_client=vault_client)

        encrypted_client = rs_or_single_client(auto_encryption_opts=auto_opts,
                                               uuidRepresentation='standard')
        self.addCleanup(encrypted_client.close)

        explicit = ClientEncryption(self.kms_providers(),
                                    'keyvault.datakeys',
                                    vault_client, OPTS)
        self.addCleanup(explicit.close)

        def auto_insert():
            encrypted_client.db.coll.insert_one({"encrypted": "test"})

        def explicit_encrypt():
            explicit.encrypt(
                "test",
                Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
                key_id=LOCAL_KEY_ID)

        # Auto encryption first, then explicit encryption.
        for operation in (auto_insert, explicit_encrypt):
            if not with_external_key_vault:
                operation()
                continue
            # Authentication error.
            with self.assertRaises(EncryptionError) as ctx:
                operation()
            # AuthenticationFailed error.
            self.assertIsInstance(ctx.exception.cause, OperationFailure)
            self.assertEqual(ctx.exception.cause.code, 18)
示例#10
0
    def find_or_create_data_key(self):
        """
        In the guide, https://docs.mongodb.com/ecosystem/use-cases/client-side-field-level-encryption-guide/,
        we create the data key and then show that it is created by
        using a find_one query. Here, in implementation, we only create the key if
        it doesn't already exist, ensuring we only have one local data key.

        We also use the key_alt_names field and provide a key alt name to aid in
        finding the key in the clients.py script.

        Returns the base64 text (UTF-8 decoded) of the data key's 16 UUID bytes.
        """
        codec_options = CodecOptions(uuid_representation=STANDARD)

        key_vault_client = MongoClient(self.connection_string)
        try:
            # Read with STANDARD UUID decoding so an existing key's ``_id``
            # (binary subtype 4) is a ``uuid.UUID``. Otherwise it may decode as
            # ``bson.Binary``, which has no ``.bytes`` attribute.
            key_vault = key_vault_client[self.key_db].get_collection(
                self.key_coll, codec_options=codec_options)

            self.ensure_unique_index_on_key_vault(key_vault)

            data_key = key_vault.find_one(
                {"keyAltNames": self.key_alt_name}
            )

            if data_key is None:
                with ClientEncryption(self.kms_providers,
                                      self.key_vault_namespace,
                                      key_vault_client,
                                      codec_options) as client_encryption:
                    data_key = client_encryption.create_data_key(
                        "local", key_alt_names=[self.key_alt_name])
                uuid_data_key_id = UUID(bytes=data_key)
            else:
                uuid_data_key_id = data_key["_id"]
        finally:
            # The client is only used inside this method; release it.
            key_vault_client.close()

        return base64.b64encode(uuid_data_key_id.bytes).decode("utf-8")
示例#11
0
    def test_encrypt_decrypt(self):
        """Explicit encrypt/decrypt round trip using key_id and key_alt_name."""
        encryption = ClientEncryption(KMS_PROVIDERS,
                                      'keyvault.datakeys',
                                      client_context.client, OPTS)
        self.addCleanup(encryption.close)
        # Use standard UUID representation when reading the key vault.
        datakeys = client_context.client.keyvault.get_collection(
            'datakeys', codec_options=OPTS)
        self.addCleanup(datakeys.drop)

        def new_local_key(alt_name):
            # Create a local data key and check it was stored in the vault.
            created = encryption.create_data_key('local',
                                                 key_alt_names=[alt_name])
            self.assertBinaryUUID(created)
            self.assertTrue(datakeys.find_one({'_id': created}))
            return created

        # The key for the encrypted field, plus an unused one so key
        # selection really has to filter.
        key_id = new_local_key('name')
        new_local_key('unused')

        plaintext = '000'
        deterministic = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
        by_id = encryption.encrypt(plaintext, deterministic, key_id=key_id)

        # Addressing the same key by its alt name must give the same output.
        by_alt_name = encryption.encrypt(plaintext, deterministic,
                                         key_alt_name='name')
        self.assertEqual(by_id, by_alt_name)

        # Decryption restores the plaintext.
        self.assertEqual(encryption.decrypt(by_id), plaintext)
示例#12
0
    def test_data_key(self):
        """Create local and AWS data keys, then encrypt by key id and alt name."""
        listener = OvertCommandListener()
        client = rs_or_single_client(event_listeners=[listener])
        self.addCleanup(client.close)
        client.db.coll.drop()
        vault = create_key_vault(client.keyvault.datakeys)
        self.addCleanup(vault.drop)

        # Configure the encrypted field via the local schema_map option.
        placeholder_schema = {
            "bsonType": "object",
            "properties": {
                "encrypted_placeholder": {
                    "encrypt": {
                        "keyId": "/placeholder",
                        "bsonType": "string",
                        "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Random"
                    }
                }
            }
        }
        auto_opts = AutoEncryptionOpts(self.kms_providers(),
                                       'keyvault.datakeys',
                                       schema_map={"db.coll": placeholder_schema})
        encrypted_client = rs_or_single_client(auto_encryption_opts=auto_opts,
                                               uuidRepresentation='standard')
        self.addCleanup(encrypted_client.close)

        encryption = ClientEncryption(self.kms_providers(),
                                      'keyvault.datakeys', client, OPTS)
        self.addCleanup(encryption.close)

        deterministic = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
        aws_master_key = {
            'region': 'us-east-1',
            'key': ('arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-'
                    '9f25-e30687b580d0'),
        }
        # (provider, key alt name, extra create_data_key kwargs, plaintext),
        # processed in this order: local first, then AWS.
        scenarios = (
            ('local', 'local_altname', {}, 'hello local'),
            ('aws', 'aws_altname', {'master_key': aws_master_key}, 'hello aws'),
        )
        ciphertexts = {}
        for provider, alt_name, extra_kwargs, plaintext in scenarios:
            # Create data key; it must be inserted with majority write concern.
            listener.reset()
            datakey_id = encryption.create_data_key(
                provider, key_alt_names=[alt_name], **extra_kwargs)
            self.assertBinaryUUID(datakey_id)
            last_started = listener.results['started'][-1]
            self.assertEqual('insert', last_started.command_name)
            self.assertEqual({'w': 'majority'},
                             last_started.command.get('writeConcern'))
            stored = list(vault.find({'_id': datakey_id}))
            self.assertEqual(len(stored), 1)
            self.assertEqual(stored[0]['masterKey']['provider'], provider)

            # Encrypt by key_id, then round-trip through auto decryption.
            by_id = encryption.encrypt(plaintext, deterministic,
                                       key_id=datakey_id)
            self.assertEncrypted(by_id)
            encrypted_client.db.coll.insert_one({
                '_id': provider,
                'value': by_id
            })
            round_tripped = encrypted_client.db.coll.find_one({'_id': provider})
            self.assertEqual(round_tripped['value'], plaintext)

            # Encrypt by key_alt_name: deterministic, so identical output.
            by_alt_name = encryption.encrypt(plaintext, deterministic,
                                             key_alt_name=alt_name)
            self.assertEqual(by_alt_name, by_id)
            ciphertexts[provider] = by_id

        # Explicitly encrypting an auto encrypted field.
        msg = (r'Cannot encrypt element of type binData because schema '
               r'requires that type is one of: \[ string \]')
        with self.assertRaisesRegex(EncryptionError, msg):
            encrypted_client.db.coll.insert_one(
                {'encrypted_placeholder': ciphertexts['local']})
示例#13
0
 def test_close(self):
     """close() is idempotent and every later operation is rejected."""
     encryption = ClientEncryption(KMS_PROVIDERS,
                                   'keyvault.datakeys',
                                   client_context.client, OPTS)
     # Close can be called multiple times.
     for _ in range(2):
         encryption.close()
     algorithm = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
     closed_msg = 'Cannot use closed ClientEncryption'
     operations = (
         lambda: encryption.create_data_key('local'),
         lambda: encryption.encrypt('val', algorithm, key_alt_name='name'),
         lambda: encryption.decrypt(Binary(b'', 6)),
     )
     for operation in operations:
         with self.assertRaisesRegex(InvalidOperation, closed_msg):
             operation()
示例#14
0
    def test_codec_options(self):
        """codec_options is mandatory and governs UUID (de)serialization."""
        with self.assertRaisesRegex(TypeError, 'codec_options must be'):
            ClientEncryption(KMS_PROVIDERS, 'keyvault.datakeys',
                             client_context.client, None)

        deterministic = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic

        legacy = ClientEncryption(
            KMS_PROVIDERS, 'keyvault.datakeys', client_context.client,
            CodecOptions(uuid_representation=JAVA_LEGACY))
        self.addCleanup(legacy.close)

        # Create the encrypted field's data key.
        key_id = legacy.create_data_key('local')

        # Round-trip a UUID with JAVA_LEGACY codec options.
        sample_uuid = uuid.uuid4()
        legacy_ciphertext = legacy.encrypt(sample_uuid, deterministic,
                                           key_id=key_id)
        self.assertEqual(legacy.decrypt(legacy_ciphertext), sample_uuid)

        # Round-trip the same UUID with STANDARD codec options.
        standard = ClientEncryption(KMS_PROVIDERS, 'keyvault.datakeys',
                                    client_context.client, OPTS)
        self.addCleanup(standard.close)
        standard_ciphertext = standard.encrypt(sample_uuid, deterministic,
                                               key_id=key_id)
        self.assertEqual(standard.decrypt(standard_ciphertext), sample_uuid)

        # The codec options take part in encryption...
        self.assertNotEqual(standard_ciphertext, legacy_ciphertext)
        # ...and in decryption.
        self.assertEqual(legacy.decrypt(standard_ciphertext), sample_uuid)
        self.assertNotEqual(standard.decrypt(legacy_ciphertext), sample_uuid)
示例#15
0
 def setUpClass(cls):
     super(TestCustomEndpoint, cls).setUpClass()
     # One ClientEncryption shared by the class, configured with AWS only.
     aws_only_providers = {'aws': AWS_CREDS}
     cls.client_encryption = ClientEncryption(
         aws_only_providers, 'keyvault.datakeys', client_context.client, OPTS)
示例#16
0
import os

from pymongo import MongoClient
from pymongo.encryption import (Algorithm, ClientEncryption)

from .db import client, db


# A fresh random 96-byte local master key is generated on every run, and the
# key vault is dropped below, so no keys persist between runs.
# (`os` was previously used here without being imported.)
local_master_key = os.urandom(96)
kms_providers = {"local": {"key": local_master_key}}

key_vault_namespace = "mydb.__keyVault"
key_vault_db, key_vault_coll = key_vault_namespace.split('.', 1)

key_vault = client[key_vault_db][key_vault_coll]
key_vault.drop()
# Unique index on keyAltNames, restricted to documents that have the field,
# so that two data keys cannot share an alt name.
key_vault.create_index(
    "keyAltNames",
    unique=True,
    partialFilterExpression={"keyAltNames": {"$exists": True}}
)

client_encryption = ClientEncryption(
    kms_providers,
    key_vault_namespace,
    client,
    db.coll.codec_options
)

data_key_id = client_encryption.create_data_key(
    'local', key_alt_names=["trainee_mi"]
)
示例#17
0
def get_connection(with_enc=False):
    """
    Return the global pooled MongoClient, creating it on the first call.

    The client has automatic decryption enabled; automatic encryption is
    bypassed, so fields must be encrypted explicitly.
    -> MongoClient
    If with_enc=True this also returns the ClientEncryption used to encrypt
    fields for this connection.
    -> MongoClient, ClientEncryption

    Raises:
        KeyError: on the first (initializing) call, if the
            SOFI_BIFROST_ENCRYPTION_KEY environment variable is not set.
    """
    global CONNECTION
    global CLIENT_ENC
    if CONNECTION is not None and CLIENT_ENC is not None:
        return (CONNECTION, CLIENT_ENC) if with_enc else CONNECTION

    # Key must be 96 bytes (stored base64-encoded in the environment).
    local_master_key_raw = os.environ["SOFI_BIFROST_ENCRYPTION_KEY"]
    local_master_key = binascii.a2b_base64(local_master_key_raw.encode())

    kms_providers = {"local": {"key": local_master_key}}
    # The MongoDB namespace (db.collection) used to store
    # the encryption data keys.
    key_vault_namespace = SOFI_BIFROST_ENCRYPTION_NAMESPACE
    # NOTE(review): the index and key lookup below use the collection
    # ENCRYPTION_DB.ENCRYPTION_KEY_NAME; this presumably equals
    # SOFI_BIFROST_ENCRYPTION_NAMESPACE — confirm.
    key_vault_db_name, key_name = (
        ENCRYPTION_DB,
        ENCRYPTION_KEY_NAME,
    )

    # bypass_auto_encryption=True disable automatic encryption but keeps
    # the automatic _decryption_ behavior. bypass_auto_encryption will
    # also disable spawning mongocryptd.
    auto_encryption_opts = AutoEncryptionOpts(kms_providers,
                                              key_vault_namespace,
                                              bypass_auto_encryption=True)

    if DEBUG:
        client = MongoClient(auto_encryption_opts=auto_encryption_opts)
    else:
        client = MongoClient(BIFROST_MONGO_CONN,
                             auto_encryption_opts=auto_encryption_opts)

    if not client.is_primary:
        # MongoClient.primary is only the (host, port) of the primary (or
        # None), not a client, so it must not replace `client`. When the
        # driver knows the replica-set topology, it routes writes to the
        # primary itself.
        logging.debug(
            "MongoDB client is not connected to the primary (primary: %s)",
            client.primary)

    coll = client.test.coll

    # First time key setup. Index creation in mongo is idempotent
    key_vault = client[key_vault_db_name][key_name]
    key_vault.create_index(
        "keyAltNames",
        unique=True,
        partialFilterExpression={"keyAltNames": {
            "$exists": True
        }},
    )

    client_encryption = ClientEncryption(
        kms_providers,
        key_vault_namespace,
        # The MongoClient to use for reading/writing to the key vault.
        # This can be the same MongoClient used by the main application.
        client,
        # The CodecOptions class used for encrypting and decrypting.
        # This should be the same CodecOptions instance you have configured
        # on MongoClient, Database, or Collection.
        coll.codec_options,
    )

    # The key is created below under the alt name `key_name`, so look for
    # exactly that alt name. An unrelated key in the vault must not count
    # as present.
    existing = key_vault.find_one({"keyAltNames": key_name})
    if existing is None:
        client_encryption.create_data_key("local",
                                          key_alt_names=[key_name])

    CONNECTION = client
    CLIENT_ENC = client_encryption
    return (CONNECTION, CLIENT_ENC) if with_enc else CONNECTION
示例#18
0
import os
from base64 import b64decode, b64encode

from pymongo import MongoClient
from pymongo.encryption import ClientEncryption
from bson import binary
from bson.codec_options import CodecOptions

codec_opts = CodecOptions(uuid_representation=binary.STANDARD)

# Test key material generated by: echo $(head -c 96 /dev/urandom | base64 | tr -d '\n')
if "LOCAL_MASTER_KEY" not in os.environ:
    raise Exception("Set LOCAL_MASTER_KEY env variable to 96 bytes of base64")

master_key = binary.Binary(b64decode(os.environ["LOCAL_MASTER_KEY"]))

# A single client is used both to reset the key vault and as the
# ClientEncryption key vault client.
key_vault_client = MongoClient("mongodb://localhost/")

# Reset the collection
key_vault_client.lab7.key_vault.drop()

# Configure a ClientEncryption object to create data keys
kms_providers = {"local": {"key": master_key}}
client_encryption = ClientEncryption(
    kms_providers, "lab7.key_vault", key_vault_client, codec_opts)

key_uuid = client_encryption.create_data_key("local")

# Store the key id into a file for easy access
with open("key_uuid.txt", "w") as key_file:
    key_file.write(b64encode(key_uuid).decode("utf-8"))
print("Created data key in lab7.key_vault with UUID: %s" % key_uuid.hex())
示例#19
0
    def _test_corpus(self, opts):
        """Run the CSFLE corpus test with the given AutoEncryptionOpts.

        Corpus entries with ``method == 'explicit'`` are encrypted with
        ClientEncryption.encrypt before insertion. Entries marked not
        ``allowed`` must fail to encrypt. All other entries are left for
        automatic encryption. The document read back must equal the corpus.
        The stored ciphertexts are then compared with corpus-encrypted.json:
        equal for deterministic, different for random.
        """
        # Drop and create the collection 'db.coll' with jsonSchema.
        coll = create_with_schema(
            self.client.db.coll,
            self.fix_up_schema(json_data('corpus', 'corpus-schema.json')))
        self.addCleanup(coll.drop)

        vault = create_key_vault(self.client.keyvault.datakeys,
                                 json_data('corpus', 'corpus-key-local.json'),
                                 json_data('corpus', 'corpus-key-aws.json'))
        self.addCleanup(vault.drop)

        client_encrypted = rs_or_single_client(auto_encryption_opts=opts,
                                               uuidRepresentation='standard')
        self.addCleanup(client_encrypted.close)

        client_encryption = ClientEncryption(self.kms_providers(),
                                             'keyvault.datakeys',
                                             client_context.client, OPTS)
        self.addCleanup(client_encryption.close)

        corpus = self.fix_up_curpus(json_data('corpus', 'corpus.json'))
        corpus_copied = SON()
        for key, value in corpus.items():
            corpus_copied[key] = copy.deepcopy(value)
            if key in ('_id', 'altname_aws', 'altname_local'):
                continue
            if value['method'] == 'auto':
                continue
            if value['method'] == 'explicit':
                identifier = value['identifier']
                self.assertIn(identifier, ('id', 'altname'))
                kms = value['kms']
                self.assertIn(kms, ('local', 'aws'))
                if identifier == 'id':
                    if kms == 'local':
                        kwargs = dict(key_id=LOCAL_KEY_ID)
                    else:
                        kwargs = dict(key_id=AWS_KEY_ID)
                else:
                    kwargs = dict(key_alt_name=kms)

                self.assertIn(value['algo'], ('det', 'rand'))
                if value['algo'] == 'det':
                    algo = (
                        Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic)
                else:
                    algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random

                try:
                    encrypted_val = client_encryption.encrypt(
                        value['value'], algo, **kwargs)
                except Exception:
                    if value['allowed']:
                        tb = traceback.format_exc()
                        self.fail('encrypt failed: %r: %r, traceback: %s' %
                                  (key, value, tb))
                else:
                    # Checked outside the try body: self.fail raises
                    # AssertionError, which `except Exception` would
                    # otherwise catch and silently discard.
                    if not value['allowed']:
                        self.fail('encrypt should have failed: %r: %r' %
                                  (key, value))
                    corpus_copied[key]['value'] = encrypted_val

        client_encrypted.db.coll.insert_one(corpus_copied)
        corpus_decrypted = client_encrypted.db.coll.find_one()
        self.assertEqual(corpus_decrypted, corpus)

        corpus_encrypted_expected = self.fix_up_curpus_encrypted(
            json_data('corpus', 'corpus-encrypted.json'), corpus)
        corpus_encrypted_actual = coll.find_one()
        for key, value in corpus_encrypted_actual.items():
            if key in ('_id', 'altname_aws', 'altname_local'):
                continue

            # Deterministic ciphertext is reproducible; random is not.
            if value['algo'] == 'det':
                self.assertEqual(value['value'],
                                 corpus_encrypted_expected[key]['value'], key)
            elif value['algo'] == 'rand' and value['allowed']:
                self.assertNotEqual(value['value'],
                                    corpus_encrypted_expected[key]['value'],
                                    key)

            if value['allowed']:
                decrypt_actual = client_encryption.decrypt(value['value'])
                decrypt_expected = client_encryption.decrypt(
                    corpus_encrypted_expected[key]['value'])
                self.assertEqual(decrypt_actual, decrypt_expected, key)
            else:
                # Disallowed values are stored unencrypted.
                self.assertEqual(value['value'], corpus[key]['value'], key)
class CsfleHelper:
    """Helper for MongoDB client-side field level encryption (CSFLE) setup.

    Holds the KMS provider configuration and key vault location. Provides
    helpers to find or create a data key, build a JSON schema that uses it,
    and open regular or CSFLE-enabled MongoClients.
    """

    def __init__(self,
                 kms_providers=None,
                 key_db="encryption",
                 key_coll="__keyVault",
                 key_alt_name=None,
                 schema=None,
                 connection_string=None,
                 mongocryptd_bypass_spawn=False,
                 mongocryptd_spawn_path="mongocryptd"):
        """Store the CSFLE configuration.

        Args:
            kms_providers: KMS provider configuration (required).
            key_db: Database holding the key vault collection.
            key_coll: Key vault collection name.
            key_alt_name: Alternate name used to find/create the data key.
            schema: Stored as ``self.schema``; not read by the methods here.
            connection_string: MongoDB URI used by every client created here.
            mongocryptd_bypass_spawn: Forwarded to AutoEncryptionOpts.
            mongocryptd_spawn_path: Forwarded to AutoEncryptionOpts.

        Raises:
            ValueError: If ``kms_providers`` is None.
        """
        super().__init__()
        if kms_providers is None:
            raise ValueError("kms_provider is required")
        self.kms_providers = kms_providers
        self.key_alt_name = key_alt_name
        self.key_db = key_db
        self.key_coll = key_coll
        self.key_vault_namespace = f"{self.key_db}.{self.key_coll}"
        self.schema = schema
        # Set by find_or_create_data_key().
        self.client_encryption = None
        self.connection_string = connection_string
        self.mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
        self.mongocryptd_spawn_path = mongocryptd_spawn_path

    def ensure_unique_index_on_key_vault(self, key_vault):
        """Create a unique ``keyAltNames`` index on the key vault collection.

        The partial filter limits the index to documents that have the field,
        so keys without alt names do not collide with each other.
        """
        key_vault.create_index(
            "keyAltNames",
            unique=True,
            partialFilterExpression={"keyAltNames": {
                "$exists": True
            }})

    def find_or_create_data_key(self):
        """Return the id of the data key named ``self.key_alt_name``, creating it if absent.

        Also assigns ``self.client_encryption`` to a ClientEncryption bound to
        the key vault. The MongoClient opened here stays in use by it.

        Returns:
            tuple: ``(uuid_data_key_id, base_64_data_key_id)``. These are the
            key id as ``uuid.UUID`` and the base64 text (UTF-8 decoded) of
            its 16 raw bytes.
        """
        codec_options = CodecOptions(uuid_representation=STANDARD)

        key_vault_client = MongoClient(self.connection_string)

        # Read the key vault with STANDARD UUID decoding so an existing key's
        # ``_id`` (BSON binary subtype 4) comes back as ``uuid.UUID``. With
        # the client's default uuidRepresentation it may decode as
        # ``bson.Binary``, which has no ``.bytes`` attribute.
        key_vault = key_vault_client[self.key_db].get_collection(
            self.key_coll, codec_options=codec_options)

        self.ensure_unique_index_on_key_vault(key_vault)

        data_key = key_vault.find_one({"keyAltNames": self.key_alt_name})

        self.client_encryption = ClientEncryption(
            self.kms_providers, self.key_vault_namespace, key_vault_client,
            codec_options)

        if data_key is None:
            data_key = self.client_encryption.create_data_key(
                "local", key_alt_names=[self.key_alt_name])
            uuid_data_key_id = UUID(bytes=data_key)
        else:
            uuid_data_key_id = data_key["_id"]

        base_64_data_key_id = (base64.b64encode(
            uuid_data_key_id.bytes).decode("utf-8"))

        return uuid_data_key_id, base_64_data_key_id

    def get_regular_client(self):
        """Return a new MongoClient for ``connection_string`` without CSFLE options."""
        return MongoClient(self.connection_string)

    def get_csfle_enabled_client(self, schema):
        """Return a new MongoClient configured with AutoEncryptionOpts.

        Args:
            schema: Passed as ``schema_map`` to AutoEncryptionOpts.
        """
        # NOTE(review): bypass_auto_encryption=True disables automatic
        # encryption and leaves only automatic decryption. The schema_map
        # therefore does not drive encryption here — confirm this is intended.
        return MongoClient(
            self.connection_string,
            auto_encryption_opts=AutoEncryptionOpts(
                self.kms_providers,
                self.key_vault_namespace,
                mongocryptd_bypass_spawn=self.mongocryptd_bypass_spawn,
                mongocryptd_spawn_path=self.mongocryptd_spawn_path,
                bypass_auto_encryption=True,
                schema_map=schema),
            connect=False)

    @staticmethod
    def create_json_schema(data_key):
        """Build the CSFLE JSON schema for the user document fields.

        Args:
            data_key: Base64 text of the data key's id bytes (for example, the
                second value returned by ``find_or_create_data_key``). It is
                decoded and wrapped as binary subtype 4 (UUID) in
                ``encryptMetadata.keyId``.

        Returns:
            dict: ``email`` and ``login`` use deterministic encryption;
            ``password``, ``logindata.password`` and ``logindata.login`` use
            random encryption.
        """
        return {
            'bsonType': 'object',
            'encryptMetadata': {
                'keyId': [Binary(base64.b64decode(data_key), 4)]
            },
            'properties': {
                'email': {
                    'encrypt': {
                        'bsonType': "string",
                        'algorithm':
                        "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
                    }
                },
                'password': {
                    'encrypt': {
                        'bsonType': "string",
                        'algorithm': "AEAD_AES_256_CBC_HMAC_SHA_512-Random"
                    }
                },
                'login': {
                    'encrypt': {
                        'bsonType': "string",
                        'algorithm':
                        "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
                    }
                },
                'logindata': {
                    'bsonType': "object",
                    'properties': {
                        'password': {
                            'encrypt': {
                                'bsonType': "string",
                                'algorithm':
                                "AEAD_AES_256_CBC_HMAC_SHA_512-Random"
                            }
                        },
                        'login': {
                            'encrypt': {
                                'bsonType': "string",
                                'algorithm':
                                "AEAD_AES_256_CBC_HMAC_SHA_512-Random"
                            }
                        }
                    }
                }
            }
        }