def test_encrypt_decrypt_two_raw_keys(self):
    """Checks a wrapper built from two RAW keys.

    Both keys use the RAW output prefix type, and the second one is marked
    primary. Ciphertexts produced directly by either underlying primitive
    must decrypt through the wrapper. A round trip through the wrapper
    itself must also succeed.
    """
    first_daead, first_key = self.new_primitive_key_pair(1234, tink_pb2.RAW)
    second_daead, second_key = self.new_primitive_key_pair(5678, tink_pb2.RAW)
    first_ct = first_daead.encrypt_deterministically(
        b'plaintext1', b'associated_data1')
    second_ct = second_daead.encrypt_deterministically(
        b'plaintext2', b'associated_data2')

    pset = primitive_set.new_primitive_set(deterministic_aead.DeterministicAead)
    pset.add_primitive(first_daead, first_key)
    primary_entry = pset.add_primitive(second_daead, second_key)
    pset.set_primary(primary_entry)
    wrapper = deterministic_aead_wrapper.DeterministicAeadWrapper()
    wrapped_daead = wrapper.wrap(pset)

    # Ciphertexts created outside the wrapper, one per RAW key, in key order.
    direct_cases = (
        (first_ct, b'associated_data1', b'plaintext1'),
        (second_ct, b'associated_data2', b'plaintext2'),
    )
    for ciphertext, associated_data, expected in direct_cases:
        self.assertEqual(
            wrapped_daead.decrypt_deterministically(ciphertext,
                                                    associated_data),
            expected)

    # Round trip entirely through the wrapper.
    round_trip_ct = wrapped_daead.encrypt_deterministically(
        b'plaintext', b'associated_data')
    self.assertEqual(
        wrapped_daead.decrypt_deterministically(round_trip_ct,
                                                b'associated_data'),
        b'plaintext')
    def test_encrypt_decrypt_with_key_rotation_from_raw(self):
        """Checks that both ciphertexts decrypt after a RAW-to-TINK rotation.

        The old RAW key stays in the primitive set as a non-primary entry,
        and a newly added TINK key is made primary.
        """
        old_daead, old_key = self.new_primitive_key_pair(1234, tink_pb2.RAW)
        legacy_ct = old_daead.encrypt_deterministically(
            b'plaintext', b'associated_data')

        pset = primitive_set.new_primitive_set(
            deterministic_aead.DeterministicAead)
        pset.add_primitive(old_daead, old_key)
        rotated_daead, rotated_key = self.new_primitive_key_pair(
            5678, tink_pb2.TINK)
        pset.set_primary(pset.add_primitive(rotated_daead, rotated_key))
        wrapped_daead = deterministic_aead_wrapper.DeterministicAeadWrapper(
        ).wrap(pset)
        fresh_ct = wrapped_daead.encrypt_deterministically(
            b'new_plaintext', b'new_associated_data')

        # Ciphertext made by the old RAW key before rotation.
        self.assertEqual(
            wrapped_daead.decrypt_deterministically(
                legacy_ct, b'associated_data'),
            b'plaintext')
        # Ciphertext made through the wrapper with the new primary key.
        self.assertEqual(
            wrapped_daead.decrypt_deterministically(
                fresh_ct, b'new_associated_data'),
            b'new_plaintext')
  def test_decrypt_wrong_associated_data_fails(self):
    """Checks that decrypting with different associated data raises."""
    daead, key = self.new_primitive_key_pair(1234, tink_pb2.TINK)
    pset = primitive_set.new_primitive_set(deterministic_aead.DeterministicAead)
    pset.set_primary(pset.add_primitive(daead, key))
    wrapper = deterministic_aead_wrapper.DeterministicAeadWrapper()
    wrapped_daead = wrapper.wrap(pset)

    ciphertext = wrapped_daead.encrypt_deterministically(
        b'plaintext', b'associated_data')
    # Same ciphertext, but the associated data no longer matches.
    with self.assertRaisesRegex(tink_error.TinkError, 'Decryption failed'):
      wrapped_daead.decrypt_deterministically(
          ciphertext, b'wrong_associated_data')
  def test_encrypt_decrypt(self):
    """Checks determinism and a round trip for a single TINK key.

    Encrypting the same (plaintext, associated_data) pair twice through the
    wrapper must yield identical ciphertexts, which is the defining property
    of deterministic AEAD. The ciphertext must also decrypt back to the
    original plaintext.
    """
    primitive, key = self.new_primitive_key_pair(1234, tink_pb2.TINK)
    pset = primitive_set.new_primitive_set(deterministic_aead.DeterministicAead)
    entry = pset.add_primitive(primitive, key)
    pset.set_primary(entry)

    wrapped_daead = deterministic_aead_wrapper.DeterministicAeadWrapper().wrap(
        pset)

    plaintext = b'plaintext'
    associated_data = b'associated_data'
    ciphertext = wrapped_daead.encrypt_deterministically(
        plaintext, associated_data)
    # Determinism: a second encryption of identical inputs must be
    # byte-for-byte equal to the first. The original test never checked this.
    self.assertEqual(
        wrapped_daead.encrypt_deterministically(plaintext, associated_data),
        ciphertext)
    self.assertEqual(
        wrapped_daead.decrypt_deterministically(ciphertext, associated_data),
        plaintext)
  def test_decrypt_unknown_ciphertext_fails(self):
    """Checks that a ciphertext from a primitive outside the set is rejected.

    The ciphertext comes from a separately constructed fake primitive that
    is never added to the set. The set itself holds a RAW key and a primary
    TINK key. Decrypting through the wrapper must raise TinkError.
    """
    foreign_daead = helper.FakeDeterministicAead(
        'unknownFakeDeterministicAead')
    foreign_ct = foreign_daead.encrypt_deterministically(
        b'plaintext', b'associated_data')

    pset = primitive_set.new_primitive_set(deterministic_aead.DeterministicAead)
    raw_daead, raw_key = self.new_primitive_key_pair(1234, tink_pb2.RAW)
    tink_daead, tink_key = self.new_primitive_key_pair(5678, tink_pb2.TINK)
    pset.add_primitive(raw_daead, raw_key)
    pset.set_primary(pset.add_primitive(tink_daead, tink_key))
    wrapped_daead = deterministic_aead_wrapper.DeterministicAeadWrapper().wrap(
        pset)

    with self.assertRaisesRegex(tink_error.TinkError, 'Decryption failed'):
      wrapped_daead.decrypt_deterministically(foreign_ct, b'associated_data')