예제 #1
0
 def test_edit_labels_on_gce_role(self,
                                  label,
                                  add=None,
                                  remove=None,
                                  create_role_first=True,
                                  raises=None,
                                  exception_message=''):
     """Edit the labels bound to the ``hvac`` GCE-type GCP auth role.

     Optionally creates the role first, then calls ``edit_labels_on_gce_role``.
     When ``raises`` is set, asserts that exception type is raised and that its
     text contains ``exception_message``; otherwise asserts the status code.

     :param label: parameterized test-case label (unused in the body).
     :param add: labels to add, forwarded unchanged to the client.
     :param remove: labels to remove, forwarded unchanged to the client.
     :param create_role_first: create the GCE role before editing its labels.
     :param raises: exception class expected from the edit call, or None.
     :param exception_message: substring expected in the raised exception's text.
     """
     role_name = 'hvac'
     project_id = 'test-hvac-project-not-a-real-project'
     if create_role_first:
         self.client.gcp.auth.create_role(
             name=role_name,
             role_type='gce',
             project_id=project_id,
             bound_service_accounts=[
                 '*****@*****.**'
             ],
             mount_point=self.TEST_MOUNT_POINT,
         )
     if raises:
         with self.assertRaises(raises) as cm:
             self.client.gcp.auth.edit_labels_on_gce_role(
                 name=role_name,
                 add=add,
                 remove=remove,
                 mount_point=self.TEST_MOUNT_POINT,
             )
         self.assertIn(
             member=exception_message,
             container=str(cm.exception),
         )
     else:
         edit_labels_response = self.client.gcp.auth.edit_labels_on_gce_role(
             name=role_name,
             add=add,
             remove=remove,
             mount_point=self.TEST_MOUNT_POINT,
         )
         # Label the log line with the call that actually produced it; lazy
         # %-args defer formatting until the record is emitted.
         logging.debug('edit_labels_response: %s', edit_labels_response)
         # Expected status: 204 for Vault < 0.10.0, 200 otherwise.
         # TODO: figure out why newer Vault versions do not return 204 here.
         if utils.skip_if_vault_version_lt('0.10.0'):
             expected_status_code = 204
         else:
             expected_status_code = 200
         self.assertEqual(
             first=edit_labels_response.status_code,
             second=expected_status_code,
         )
예제 #2
0
 def test_create_role(self,
                      label,
                      role_type,
                      policies=None,
                      extra_params=None,
                      raises=None,
                      exception_message=''):
     """Create the ``hvac`` GCP auth role and assert on the outcome.

     :param label: parameterized test-case label (unused in the body).
     :param role_type: role type forwarded to ``create_role``.
     :param policies: policies forwarded to ``create_role``.
     :param extra_params: extra keyword arguments for ``create_role``; None means none.
     :param raises: exception class expected from ``create_role``, or None for success.
     :param exception_message: substring expected in the raised exception's text.
     """
     extra_kwargs = {} if extra_params is None else extra_params

     def create_role():
         # One call site shared by the success and failure paths.
         return self.client.gcp.auth.create_role(
             name='hvac',
             role_type=role_type,
             project_id='test-hvac-project-not-a-real-project',
             policies=policies,
             mount_point=self.TEST_MOUNT_POINT,
             **extra_kwargs)

     if not raises:
         create_role_response = create_role()
         logging.debug('create_role_response: %s' % create_role_response)
         # The test expects 204 for Vault < 0.10.0 and 200 otherwise.
         # TODO: figure out why newer Vault versions do not return 204 here.
         if utils.skip_if_vault_version_lt('0.10.0'):
             expected_status_code = 204
         else:
             expected_status_code = 200
         self.assertEqual(first=create_role_response.status_code,
                          second=expected_status_code)
         return

     with self.assertRaises(raises) as cm:
         create_role()
     self.assertIn(member=exception_message, container=str(cm.exception))
예제 #3
0
from unittest import TestCase
from unittest import skipIf

from hvac.tests import utils


@skipIf(utils.skip_if_vault_version_lt('0.9.0'),
        "Policy class uses new parameters added >= Vault 0.9.0")
class TestPolicy(utils.HvacIntegrationTestCase, TestCase):
    """Integration test covering the create/read/delete lifecycle of an ACL policy."""

    def test_policy_manipulation(self):
        """Create a policy, verify it via list/read/get, then delete it."""
        policy_name = 'test'

        def known_policies():
            # Names of all policies Vault currently reports.
            return self.client.sys.list_policies()['data']['policies']

        self.assertIn(member='root', container=known_policies())
        self.assertIsNone(self.client.get_policy(policy_name))

        policy, parsed_policy = self.prep_policy(policy_name)
        self.assertIn(member=policy_name, container=known_policies())

        stored_rules = self.client.sys.read_policy(policy_name)['data']['rules']
        self.assertEqual(policy, stored_rules)
        self.assertEqual(parsed_policy,
                         self.client.get_policy(policy_name, parse=True))

        self.client.sys.delete_policy(name=policy_name)
        self.assertNotIn(member=policy_name, container=known_policies())
예제 #4
0
class TestLdap(utils.HvacIntegrationTestCase, TestCase):
    """Integration tests for the hvac LDAP auth method client.

    A mock LDAP server is started once for the whole class in ``setUpClass``.
    Every test mounts the LDAP auth method at ``TEST_LDAP_PATH`` in ``setUp``
    and unmounts it again in ``tearDown``.
    """

    TEST_LDAP_PATH = 'test-ldap'
    # The three attributes below are assigned in setUpClass.
    ldap_server = None
    mock_server_port = None
    mock_ldap_url = None

    @classmethod
    def setUpClass(cls):
        """Start a mock LDAP server, seeded from the LDAP_* fixtures, on a free port."""
        super(TestLdap, cls).setUpClass()
        # Only surface errors from the mock server's logger.
        logging.getLogger('ldap_test').setLevel(logging.ERROR)

        cls.mock_server_port = utils.get_free_port()
        cls.mock_ldap_url = 'ldap://localhost:{port}'.format(
            port=cls.mock_server_port)
        cls.ldap_server = LdapServer({
            'port': cls.mock_server_port,
            'bind_dn': LDAP_BIND_DN,
            'password': LDAP_BIND_PASSWORD,
            'base': {
                'objectclass': ['domain'],
                'dn': LDAP_BASE_DN,
                'attributes': {
                    'dc': LDAP_BASE_DC
                }
            },
            'entries': LDAP_ENTRIES,
        })
        cls.ldap_server.start()

    @classmethod
    def tearDownClass(cls):
        """Run the parent class teardown, then stop the mock LDAP server."""
        try:
            super(TestLdap, cls).tearDownClass()
        finally:
            # Stop the mock server even if the parent teardown raises, so it is
            # not left running (and holding its port) after this class.
            cls.ldap_server.stop()

    def setUp(self):
        """Mount the LDAP auth method at TEST_LDAP_PATH unless already mounted."""
        super(TestLdap, self).setUp()
        # Look for this class's own mount path (trailing-slash form), not the
        # default 'ldap/' path, which this class never mounts.
        mount_key = '{path}/'.format(path=self.TEST_LDAP_PATH)
        if mount_key not in self.client.list_auth_backends():
            self.client.sys.enable_auth_method(method_type='ldap',
                                               path=self.TEST_LDAP_PATH)

    def tearDown(self):
        """Run the parent teardown, then unmount the LDAP auth method."""
        # NOTE(review): the parent teardown runs first, presumably so the client
        # is back on a privileged token (test_login swaps it) before unmounting;
        # confirm before reordering.
        super(TestLdap, self).tearDown()
        self.client.disable_auth_backend(mount_point=self.TEST_LDAP_PATH, )

    @parameterized.expand([
        ('update url', dict(url=LDAP_URL)),
        ('update binddn',
         dict(url=LDAP_URL, bind_dn='cn=vault,ou=Users,dc=hvac,dc=network')),
        ('update upn_domain', dict(url=LDAP_URL, upn_domain='hvac.network')),
        ('update certificate',
         dict(url=LDAP_URL,
              certificate=utils.load_test_data('server-cert.pem'))),
        ('incorrect tls version', dict(url=LDAP_URL, tls_min_version='cats'),
         exceptions.InvalidRequest, "invalid 'tls_min_version'"),
    ])
    def test_configure(self,
                       test_label,
                       parameters,
                       raises=None,
                       exception_message=''):
        """Configure the LDAP auth method and verify the stored configuration.

        ``parameters`` is extended with the user/group DNs and the mount point.
        On success, every argument except ``mount_point`` must appear among the
        values returned by ``read_configuration``.
        """
        # Build a new dict instead of updating ``parameters`` in place; the
        # dict object is owned by the decorator's case list.
        parameters = dict(
            parameters,
            user_dn=LDAP_USERS_DN,
            group_dn=LDAP_GROUPS_DN,
            mount_point=self.TEST_LDAP_PATH,
        )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.auth.ldap.configure(**parameters)
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            expected_status_code = 204
            configure_response = self.client.auth.ldap.configure(**parameters)
            self.assertEqual(first=expected_status_code,
                             second=configure_response.status_code)

            read_config_response = self.client.auth.ldap.read_configuration(
                mount_point=self.TEST_LDAP_PATH, )
            for parameter, argument in parameters.items():
                if parameter == 'mount_point':
                    continue
                self.assertIn(
                    member=argument,
                    container=read_config_response['data'].values(),
                )

    def test_read_configuration(self):
        """Reading the configuration returns a response with a 'data' key."""
        response = self.client.auth.ldap.read_configuration(
            mount_point=self.TEST_LDAP_PATH, )
        self.assertIn(
            member='data',
            container=response,
        )

    @parameterized.expand([
        ('no policies', 'cats'),
        ('policies as list', 'cats', ['purr-policy']),
        ('policies as invalid type', 'cats', 'purr-policy',
         exceptions.ParamValidationError,
         '"policies" argument must be an instance of list'),
    ])
    def test_create_or_update_group(self,
                                    test_label,
                                    name,
                                    policies=None,
                                    raises=None,
                                    exception_message=''):
        """Create or update an LDAP group; expect 204 or the given exception."""
        expected_status_code = 204
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.auth.ldap.create_or_update_group(
                    name=name,
                    policies=policies,
                    mount_point=self.TEST_LDAP_PATH,
                )
            if exception_message is not None:
                self.assertIn(
                    member=exception_message,
                    container=str(cm.exception),
                )
        else:
            create_response = self.client.auth.ldap.create_or_update_group(
                name=name,
                policies=policies,
                mount_point=self.TEST_LDAP_PATH,
            )
            self.assertEqual(first=expected_status_code,
                             second=create_response.status_code)

    @parameterized.expand([
        ('read configured groups', 'cats'),
        ('non-existent groups', 'cats', False, exceptions.InvalidPath),
    ])
    def test_list_groups(self,
                         test_label,
                         name,
                         configure_first=True,
                         raises=None,
                         exception_message=None):
        """List LDAP groups; the only key must be ``name`` when it was created."""
        if configure_first:
            self.client.auth.ldap.create_or_update_group(
                name=name,
                mount_point=self.TEST_LDAP_PATH,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.auth.ldap.list_groups(
                    mount_point=self.TEST_LDAP_PATH, )
            if exception_message is not None:
                self.assertIn(
                    member=exception_message,
                    container=str(cm.exception),
                )
        else:
            list_groups_response = self.client.auth.ldap.list_groups(
                mount_point=self.TEST_LDAP_PATH, )
            self.assertDictEqual(
                d1=dict(keys=[name]),
                d2=list_groups_response['data'],
            )

    @parameterized.expand([
        ('read configured group', 'cats'),
        ('non-existent group', 'cats', False, exceptions.InvalidPath),
    ])
    def test_read_group(self,
                        test_label,
                        name,
                        configure_first=True,
                        raises=None,
                        exception_message=None):
        """Read an LDAP group; its data must include a 'policies' entry."""
        if configure_first:
            self.client.auth.ldap.create_or_update_group(
                name=name,
                mount_point=self.TEST_LDAP_PATH,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.auth.ldap.read_group(
                    name=name,
                    mount_point=self.TEST_LDAP_PATH,
                )
            if exception_message is not None:
                self.assertIn(
                    member=exception_message,
                    container=str(cm.exception),
                )
        else:
            read_group_response = self.client.auth.ldap.read_group(
                name=name,
                mount_point=self.TEST_LDAP_PATH,
            )
            self.assertIn(
                member='policies',
                container=read_group_response['data'],
            )

    @parameterized.expand([
        ('no policies or groups', 'cats'),
        ('policies as list', 'cats', ['purr-policy']),
        ('policies as invalid type', 'cats', 'purr-policy', None,
         exceptions.ParamValidationError,
         '"policies" argument must be an instance of list'),
        ('no groups', 'cats', ['purr-policy']),
        ('groups as list', 'cats', None, ['meow-group']),
        ('groups as invalid type', 'cats', None, 'meow-group',
         exceptions.ParamValidationError,
         '"groups" argument must be an instance of list'),
    ])
    def test_create_or_update_user(self,
                                   test_label,
                                   username,
                                   policies=None,
                                   groups=None,
                                   raises=None,
                                   exception_message=''):
        """Create or update an LDAP user; expect 204 or the given exception."""
        expected_status_code = 204
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.auth.ldap.create_or_update_user(
                    username=username,
                    policies=policies,
                    groups=groups,
                    mount_point=self.TEST_LDAP_PATH,
                )
            if exception_message is not None:
                self.assertIn(
                    member=exception_message,
                    container=str(cm.exception),
                )
        else:
            create_response = self.client.auth.ldap.create_or_update_user(
                username=username,
                policies=policies,
                groups=groups,
                mount_point=self.TEST_LDAP_PATH,
            )
            self.assertEqual(first=expected_status_code,
                             second=create_response.status_code)

    @parameterized.expand([
        ('read configured group', 'cats'),
        ('non-existent group', 'cats', False, exceptions.InvalidPath),
    ])
    def test_delete_group(self,
                          test_label,
                          name,
                          configure_first=True,
                          raises=None,
                          exception_message=None):
        """Delete an LDAP group and expect a 204 response.

        NOTE(review): ``raises`` and ``exception_message`` are accepted for
        parameterization symmetry but are not used; the delete is expected to
        return 204 whether or not the group was created first.
        """
        if configure_first:
            self.client.auth.ldap.create_or_update_group(
                name=name,
                mount_point=self.TEST_LDAP_PATH,
            )
        expected_status_code = 204
        delete_group_response = self.client.auth.ldap.delete_group(
            name=name,
            mount_point=self.TEST_LDAP_PATH,
        )
        self.assertEqual(first=expected_status_code,
                         second=delete_group_response.status_code)

    @parameterized.expand([
        ('read configured user', 'cats'),
        ('non-existent user', 'cats', False, exceptions.InvalidPath),
    ])
    def test_list_users(self,
                        test_label,
                        username,
                        configure_first=True,
                        raises=None,
                        exception_message=None):
        """List LDAP users; the only key must be ``username`` when it was created."""
        if configure_first:
            self.client.auth.ldap.create_or_update_user(
                username=username,
                mount_point=self.TEST_LDAP_PATH,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.auth.ldap.list_users(
                    mount_point=self.TEST_LDAP_PATH, )
            if exception_message is not None:
                self.assertIn(
                    member=exception_message,
                    container=str(cm.exception),
                )
        else:
            list_users_response = self.client.auth.ldap.list_users(
                mount_point=self.TEST_LDAP_PATH, )
            self.assertDictEqual(
                d1=dict(keys=[username]),
                d2=list_users_response['data'],
            )

    @parameterized.expand([
        ('read configured user', 'cats'),
        ('non-existent user', 'cats', False, exceptions.InvalidPath),
    ])
    def test_read_user(self,
                       test_label,
                       username,
                       configure_first=True,
                       raises=None,
                       exception_message=None):
        """Read an LDAP user; its data must include a 'policies' entry."""
        if configure_first:
            self.client.auth.ldap.create_or_update_user(
                username=username,
                mount_point=self.TEST_LDAP_PATH,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.auth.ldap.read_user(
                    username=username,
                    mount_point=self.TEST_LDAP_PATH,
                )
            if exception_message is not None:
                self.assertIn(
                    member=exception_message,
                    container=str(cm.exception),
                )
        else:
            read_user_response = self.client.auth.ldap.read_user(
                username=username,
                mount_point=self.TEST_LDAP_PATH,
            )
            self.assertIn(
                member='policies',
                container=read_user_response['data'],
            )

    @parameterized.expand([
        ('read configured user', 'cats'),
        ('non-existent user', 'cats', False, exceptions.InvalidPath),
    ])
    def test_delete_user(self,
                         test_label,
                         username,
                         configure_first=True,
                         raises=None,
                         exception_message=None):
        """Delete an LDAP user and expect a 204 response.

        NOTE(review): ``raises`` and ``exception_message`` are accepted for
        parameterization symmetry but are not used; the delete is expected to
        return 204 whether or not the user was created first.
        """
        if configure_first:
            self.client.auth.ldap.create_or_update_user(
                username=username,
                mount_point=self.TEST_LDAP_PATH,
            )
        expected_status_code = 204
        delete_user_response = self.client.auth.ldap.delete_user(
            username=username,
            mount_point=self.TEST_LDAP_PATH,
        )
        self.assertEqual(first=expected_status_code,
                         second=delete_user_response.status_code)

    @parameterized.expand([
        param(label='working creds with policy'),
        param(
            label='invalid creds',
            username='******',
            password='******',
            attach_policy=False,
            raises=exceptions.InvalidRequest,
        ),
        # The following two test cases cover either side of the associated changelog entry for LDAP auth here:
        # https://github.com/hashicorp/vault/blob/master/CHANGELOG.md#0103-june-20th-2018
        param(
            label='working creds no membership with Vault version >= 0.10.3',
            attach_policy=False,
            skip_due_to_vault_version=utils.skip_if_vault_version_lt('0.10.3'),
        ),
        param(
            label='working creds no membership with Vault version < 0.10.3',
            attach_policy=False,
            raises=exceptions.InvalidRequest,
            exception_message='user is not a member of any authorized group',
            skip_due_to_vault_version=utils.skip_if_vault_version_ge('0.10.3'),
        ),
    ])
    def test_login(self,
                   label,
                   username=LDAP_USER_NAME,
                   password=LDAP_USER_PASSWORD,
                   attach_policy=True,
                   raises=None,
                   exception_message='',
                   skip_due_to_vault_version=False):
        """Log in through LDAP against the mock server and check the auth result.

        The auth method is pointed at the mock server. When ``attach_policy``
        is set, a test policy is attached to ``LDAP_GROUP_NAME``. On success,
        the login must set the client token and return the expected policies.
        """
        if skip_due_to_vault_version:
            self.skipTest(
                reason='test case does not apply to Vault version under test')

        test_policy_name = 'test-ldap-policy'
        self.client.auth.ldap.configure(
            url=self.mock_ldap_url,
            bind_dn=self.ldap_server.config['bind_dn'],
            bind_pass=self.ldap_server.config['password'],
            user_dn=LDAP_USERS_DN,
            user_attr='uid',
            group_dn=LDAP_GROUPS_DN,
            group_attr='cn',
            insecure_tls=True,
            mount_point=self.TEST_LDAP_PATH,
        )

        if attach_policy:
            self.prep_policy(test_policy_name)
            self.client.auth.ldap.create_or_update_group(
                name=LDAP_GROUP_NAME,
                policies=[test_policy_name],
                mount_point=self.TEST_LDAP_PATH,
            )

        if raises:
            with self.assertRaises(raises) as cm:
                self.client.auth.ldap.login(
                    username=username,
                    password=password,
                    mount_point=self.TEST_LDAP_PATH,
                )
            if exception_message is not None:
                self.assertIn(
                    member=exception_message,
                    container=str(cm.exception),
                )
        else:
            login_response = self.client.auth.ldap.login(
                username=username,
                password=password,
                mount_point=self.TEST_LDAP_PATH,
            )
            self.assertDictEqual(
                d1=dict(username=username),
                d2=login_response['auth']['metadata'],
            )
            # A successful login is expected to switch the client to the new token.
            self.assertEqual(
                first=login_response['auth']['client_token'],
                second=self.client.token,
            )
            if attach_policy:
                expected_policies = ['default', test_policy_name]
            else:
                expected_policies = ['default']
            self.assertEqual(first=expected_policies,
                             second=login_response['auth']['policies'])
예제 #5
0
import logging
from unittest import TestCase
from unittest import skipIf

from parameterized import parameterized

from hvac import exceptions
from hvac.tests import utils


@skipIf(utils.skip_if_vault_version_lt('0.11.0'), "Azure secret engine not available before Vault version 0.11.0")
class TestAzure(utils.HvacIntegrationTestCase, TestCase):
    TENANT_ID = '00000000-0000-0000-0000-000000000000'
    SUBSCRIPTION_ID = '00000000-0000-0000-0000-000000000000'
    DEFAULT_MOUNT_POINT = 'azure-integration-test'

    def setUp(self):
        """Run the base setup, then mount an Azure secrets engine for this test."""
        super(TestAzure, self).setUp()
        self.client.enable_secret_backend(backend_type='azure',
                                          mount_point=self.DEFAULT_MOUNT_POINT)

    def tearDown(self):
        """Unmount this test's Azure secrets engine, then run the base teardown."""
        self.client.disable_secret_backend(
            mount_point=self.DEFAULT_MOUNT_POINT,
        )
        super(TestAzure, self).tearDown()

    @parameterized.expand([
        ('no parameters',),
        ('valid environment argument', 'AzureUSGovernmentCloud'),
        ('invalid environment argument', 'AzureCityKity', exceptions.ParamValidationError, 'invalid environment argument provided'),
예제 #6
0
class TestTransit(utils.HvacIntegrationTestCase, TestCase):
    TEST_MOUNT_POINT = 'transit-integration-test'

    def setUp(self):
        """Run the base setup, then mount a transit secrets engine for this test."""
        super(TestTransit, self).setUp()
        self.client.enable_secret_backend(backend_type='transit',
                                          mount_point=self.TEST_MOUNT_POINT)

    def tearDown(self):
        """Unmount this test's transit secrets engine, then run the base teardown."""
        self.client.disable_secret_backend(
            mount_point=self.TEST_MOUNT_POINT,
        )
        super(TestTransit, self).tearDown()

    @parameterized.expand([
        param('success'),
    ])
    def test_create_key(self, label, raises=False, exception_message=''):
        """Create a transit key; expect a 204 response or the given exception.

        :param label: parameterized test-case label (unused in the body).
        :param raises: exception class expected from ``create_key``; falsy for success.
        :param exception_message: substring expected in the raised exception's text.
        """
        key_name = 'testkey'

        def create_key():
            return self.client.secrets.transit.create_key(
                name=key_name,
                mount_point=self.TEST_MOUNT_POINT,
            )

        if not raises:
            create_key_response = create_key()
            logging.debug('create_key_response: %s' % create_key_response)
            self.assertEqual(first=create_key_response.status_code, second=204)
            return

        with self.assertRaises(raises) as cm:
            create_key()
        self.assertIn(member=exception_message, container=str(cm.exception))

    @parameterized.expand([
        param('success'),
    ])
    def test_read_key(self, label, raises=False, exception_message=''):
        """Create a transit key, read it back, and check the returned name.

        :param label: parameterized test-case label (unused in the body).
        :param raises: exception class expected from ``read_key``; falsy for success.
        :param exception_message: substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        create_key_response = self.client.secrets.transit.create_key(
            name=key_name,
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('create_key_response: %s' % create_key_response)

        def read_key():
            return self.client.secrets.transit.read_key(
                name=key_name,
                mount_point=self.TEST_MOUNT_POINT,
            )

        if not raises:
            read_key_response = read_key()
            logging.debug('read_key_response: %s' % read_key_response)
            self.assertEqual(first=read_key_response['data']['name'],
                             second=key_name)
            return

        with self.assertRaises(raises) as cm:
            read_key()
        self.assertIn(member=exception_message, container=str(cm.exception))

    @parameterized.expand([
        param('success'),
    ])
    def test_list_keys(self, label, raises=False, exception_message=''):
        """Create one transit key and check that listing returns exactly that key.

        :param label: parameterized test-case label (unused in the body).
        :param raises: exception class expected from ``list_keys``; falsy for success.
        :param exception_message: substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        create_key_response = self.client.secrets.transit.create_key(
            name=key_name,
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('create_key_response: %s' % create_key_response)

        def list_keys():
            return self.client.secrets.transit.list_keys(
                mount_point=self.TEST_MOUNT_POINT,
            )

        if not raises:
            list_keys_response = list_keys()
            logging.debug('list_keys_response: %s' % list_keys_response)
            self.assertEqual(first=list_keys_response['data']['keys'],
                             second=[key_name])
            return

        with self.assertRaises(raises) as cm:
            list_keys()
        self.assertIn(member=exception_message, container=str(cm.exception))

    @parameterized.expand([
        param('success'),
    ])
    def test_delete_key(self, label, raises=False, exception_message=''):
        """Create a transit key, allow its deletion, delete it, and expect 204.

        :param label: parameterized test-case label (unused in the body).
        :param raises: exception class expected from ``delete_key``; falsy for success.
        :param exception_message: substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        create_key_response = self.client.secrets.transit.create_key(
            name=key_name,
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('create_key_response: %s' % create_key_response)
        # The key is reconfigured with deletion_allowed=True before deleting it.
        update_key_configuration_response = self.client.secrets.transit.update_key_configuration(
            name=key_name,
            deletion_allowed=True,
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('update_key_configuration_response: %s' % update_key_configuration_response)

        def delete_key():
            return self.client.secrets.transit.delete_key(
                name=key_name,
                mount_point=self.TEST_MOUNT_POINT,
            )

        if not raises:
            delete_key_response = delete_key()
            logging.debug('delete_key_response: %s' % delete_key_response)
            self.assertEqual(first=delete_key_response.status_code, second=204)
            return

        with self.assertRaises(raises) as cm:
            delete_key()
        self.assertIn(member=exception_message, container=str(cm.exception))

    @parameterized.expand([
        param('success'),
    ])
    def test_rotate_key(self, label, raises=False, exception_message=''):
        """Create a transit key, rotate it, and expect a 204 response.

        :param label: parameterized test-case label (unused in the body).
        :param raises: exception class expected from ``rotate_key``; falsy for success.
        :param exception_message: substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        create_key_response = self.client.secrets.transit.create_key(
            name=key_name,
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('create_key_response: %s' % create_key_response)

        def rotate_key():
            return self.client.secrets.transit.rotate_key(
                name=key_name,
                mount_point=self.TEST_MOUNT_POINT,
            )

        if not raises:
            rotate_key_response = rotate_key()
            logging.debug('rotate_key_response: %s' % rotate_key_response)
            self.assertEqual(first=rotate_key_response.status_code, second=204)
            return

        with self.assertRaises(raises) as cm:
            rotate_key()
        self.assertIn(member=exception_message, container=str(cm.exception))

    @parameterized.expand([
        param('success'),
        param(
            'invalid key type',
            key_type='kitty-cat-key',
            raises=exceptions.ParamValidationError,
            exception_message='invalid key_type argument provided',
        ),
    ])
    def test_export_key(self, label, key_type='hmac-key', raises=False, exception_message=''):
        """Create an exportable transit key and export it as ``key_type``.

        On success exactly one key version is expected, under the key's name.

        :param label: parameterized test-case label (unused in the body).
        :param key_type: export type forwarded to ``export_key``.
        :param raises: exception class expected from ``export_key``; falsy for success.
        :param exception_message: substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        create_key_response = self.client.secrets.transit.create_key(
            name=key_name,
            exportable=True,
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('create_key_response: %s' % create_key_response)

        def export_key():
            return self.client.secrets.transit.export_key(
                name=key_name,
                key_type=key_type,
                mount_point=self.TEST_MOUNT_POINT,
            )

        if not raises:
            export_key_response = export_key()
            logging.debug('export_key_response: %s' % export_key_response)
            exported = export_key_response['data']
            self.assertEqual(first=len(exported['keys']), second=1)
            self.assertEqual(first=exported['name'], second=key_name)
            return

        with self.assertRaises(raises) as cm:
            export_key()
        self.assertIn(member=exception_message, container=str(cm.exception))

    @parameterized.expand([
        param('success'),
    ])
    def test_encrypt_data(self, label, plaintext='hi itsame hvac', raises=False, exception_message=''):
        """Encrypt base64-encoded ``plaintext`` and expect a ciphertext in the response.

        :param label: parameterized test-case label (unused in the body).
        :param plaintext: raw text; it is base64-encoded before being sent.
        :param raises: exception class expected from ``encrypt_data``; falsy for success.
        :param exception_message: substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        encoded_plaintext = utils.base64ify(plaintext)
        create_key_response = self.client.secrets.transit.create_key(
            name=key_name,
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('create_key_response: %s' % create_key_response)

        def encrypt_data():
            return self.client.secrets.transit.encrypt_data(
                name=key_name,
                plaintext=encoded_plaintext,
                mount_point=self.TEST_MOUNT_POINT,
            )

        if not raises:
            encrypt_data_response = encrypt_data()
            logging.debug('encrypt_data_response: %s' % encrypt_data_response)
            self.assertIn(member='ciphertext',
                          container=encrypt_data_response['data'])
            return

        with self.assertRaises(raises) as cm:
            encrypt_data()
        self.assertIn(member=exception_message, container=str(cm.exception))

    @parameterized.expand([
        param(
            'success',
        ),
    ])
    def test_decrypt_data(self, label, plaintext='hi itsame hvac', raises=False, exception_message=''):
        """Encrypt ``plaintext`` with a fresh key, decrypt it, and check the round trip.

        ``plaintext`` is base64-encoded via ``utils.base64ify`` before encryption,
        and the decrypted value is compared against that encoded string.

        :param label: parameterized test-case label (unused in the body).
        :param plaintext: raw text to round-trip through encrypt/decrypt.
        :param raises: exception class expected from ``decrypt_data``; falsy for success.
        :param exception_message: substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        plaintext = utils.base64ify(plaintext)
        create_key_response = self.client.secrets.transit.create_key(
            name=key_name,
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('create_key_response: %s', create_key_response)
        encrypt_data_response = self.client.secrets.transit.encrypt_data(
            name=key_name,
            plaintext=plaintext,
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('encrypt_data_response: %s', encrypt_data_response)
        ciphertext = encrypt_data_response['data']['ciphertext']
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.secrets.transit.decrypt_data(
                    name=key_name,
                    ciphertext=ciphertext,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            decrypt_data_response = self.client.secrets.transit.decrypt_data(
                name=key_name,
                ciphertext=ciphertext,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('decrypt_data_response: %s', decrypt_data_response)
            # The round trip must reproduce the input exactly. assertIn on two
            # strings only checks for a substring and could pass on a result
            # that merely contains the plaintext.
            self.assertEqual(
                first=decrypt_data_response['data']['plaintext'],
                second=plaintext,
            )

    @parameterized.expand([
        param(
            'success',
        ),
    ])
    def test_rewrap_data(self, label, plaintext='hi itsame hvac', raises=False, exception_message=''):
        """Rewrap a ciphertext after the key that produced it has been rotated.

        A key is created and used to encrypt the base64-encoded plaintext. The key
        is then rotated, and the original ciphertext is passed to ``rewrap_data``.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param plaintext: Raw text; it is base64-encoded before being encrypted.
        :param raises: Exception type ``rewrap_data`` is expected to raise, or a
            falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        encoded_plaintext = utils.base64ify(plaintext)
        transit = self.client.secrets.transit
        mount_point = self.TEST_MOUNT_POINT

        created = transit.create_key(name=key_name, mount_point=mount_point)
        logging.debug('create_key_response: %s', created)

        encrypted = transit.encrypt_data(
            name=key_name,
            plaintext=encoded_plaintext,
            mount_point=mount_point,
        )
        logging.debug('encrypt_data_response: %s', encrypted)
        original_ciphertext = encrypted['data']['ciphertext']

        rotated = transit.rotate_key(name=key_name, mount_point=mount_point)
        logging.debug('rotate_key_response: %s', rotated)

        rewrap_kwargs = dict(
            name=key_name,
            ciphertext=original_ciphertext,
            mount_point=mount_point,
        )
        if raises:
            with self.assertRaises(raises) as ctx:
                transit.rewrap_data(**rewrap_kwargs)
            self.assertIn(member=exception_message, container=str(ctx.exception))
            return

        rewrapped = transit.rewrap_data(**rewrap_kwargs)
        logging.debug('rewrap_data_response: %s', rewrapped)
        self.assertIn(member='ciphertext', container=rewrapped['data'])

    @parameterized.expand([
        param(
            'success',
        ),
        param(
            'invalid key type',
            key_type='kitty-cat-key',
            raises=exceptions.ParamValidationError,
            exception_message='invalid key_type argument provided',
        ),
    ])
    def test_generate_data_key(self, label, key_type='plaintext', raises=False, exception_message=''):
        """Generate a data key under a freshly created transit key.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param key_type: Data-key type forwarded to ``generate_data_key``.
        :param raises: Exception type ``generate_data_key`` is expected to raise,
            or a falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        transit = self.client.secrets.transit

        created = transit.create_key(name=key_name, mount_point=self.TEST_MOUNT_POINT)
        logging.debug('create_key_response: %s', created)

        data_key_kwargs = dict(
            name=key_name,
            key_type=key_type,
            mount_point=self.TEST_MOUNT_POINT,
        )
        if raises:
            with self.assertRaises(raises) as ctx:
                transit.generate_data_key(**data_key_kwargs)
            self.assertIn(member=exception_message, container=str(ctx.exception))
            return

        data_key = transit.generate_data_key(**data_key_kwargs)
        logging.debug('gen_data_key_response: %s', data_key)
        self.assertIn(member='ciphertext', container=data_key['data'])

    @parameterized.expand([
        param(
            'success',
        ),
    ])
    def test_generate_random_bytes(self, label, n_bytes=32, raises=False, exception_message=''):
        """Request random bytes from the transit ``generate_random_bytes`` endpoint.

        No transit key is created, because this call takes no key name.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param n_bytes: Number of random bytes to request.
        :param raises: Exception type ``generate_random_bytes`` is expected to
            raise, or a falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.secrets.transit.generate_random_bytes(
                    n_bytes=n_bytes,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            gen_bytes_response = self.client.secrets.transit.generate_random_bytes(
                n_bytes=n_bytes,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('gen_bytes_response: %s' % gen_bytes_response)
            self.assertIn(
                member='random_bytes',
                container=gen_bytes_response['data'],
            )

    @parameterized.expand([
        param(
            'success',
        ),
        param(
            'invalid algorithm',
            algorithm='meow2-256',
            raises=exceptions.ParamValidationError,
            exception_message='invalid algorithm argument provided',
        ),
        param(
            'invalid output_format',
            output_format='kitty64',
            raises=exceptions.ParamValidationError,
            exception_message='invalid output_format argument provided',
        ),
    ])
    def test_hash_data(self, label, hash_input='hash this ish', algorithm='sha2-256', output_format='hex', raises=False, exception_message=''):
        """Hash a base64-encoded input through the transit ``hash_data`` endpoint.

        Unlike the key-based tests, this test creates no transit key.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param hash_input: Raw text; it is base64-encoded before being sent.
        :param algorithm: Hash algorithm name forwarded to ``hash_data``.
        :param output_format: Output encoding name forwarded to ``hash_data``.
        :param raises: Exception type ``hash_data`` is expected to raise, or a
            falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        hash_kwargs = dict(
            hash_input=utils.base64ify(hash_input),
            algorithm=algorithm,
            output_format=output_format,
            mount_point=self.TEST_MOUNT_POINT,
        )
        transit = self.client.secrets.transit

        if raises:
            with self.assertRaises(raises) as ctx:
                transit.hash_data(**hash_kwargs)
            self.assertIn(member=exception_message, container=str(ctx.exception))
            return

        hashed = transit.hash_data(**hash_kwargs)
        logging.debug('hash_data_response: %s', hashed)
        self.assertIn(member='sum', container=hashed['data'])

    @parameterized.expand([
        param(
            'success',
        ),
        param(
            'invalid algorithm',
            algorithm='meow2-256',
            raises=exceptions.ParamValidationError,
            exception_message='invalid algorithm argument provided',
        ),
    ])
    def test_generate_hmac(self, label, hash_input='hash this ish', algorithm='sha2-256', raises=False, exception_message=''):
        """Compute an HMAC of a base64-encoded input with a freshly created transit key.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param hash_input: Raw text; it is base64-encoded before being sent.
        :param algorithm: Hash algorithm name forwarded to ``generate_hmac``.
        :param raises: Exception type ``generate_hmac`` is expected to raise, or a
            falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        encoded_input = utils.base64ify(hash_input)
        key_name = 'testkey'
        transit = self.client.secrets.transit

        created = transit.create_key(name=key_name, mount_point=self.TEST_MOUNT_POINT)
        logging.debug('create_key_response: %s', created)

        hmac_kwargs = dict(
            name=key_name,
            hash_input=encoded_input,
            algorithm=algorithm,
            mount_point=self.TEST_MOUNT_POINT,
        )
        if raises:
            with self.assertRaises(raises) as ctx:
                transit.generate_hmac(**hmac_kwargs)
            self.assertIn(member=exception_message, container=str(ctx.exception))
            return

        hmac_result = transit.generate_hmac(**hmac_kwargs)
        logging.debug('generate_hmac_response: %s', hmac_result)
        self.assertIn(member='hmac', container=hmac_result['data'])

    @parameterized.expand([
        param(
            'success',
        ),
        param(
            'invalid algorithm',
            hash_algorithm='meow2-256',
            raises=exceptions.ParamValidationError,
            exception_message='invalid hash_algorithm argument provided',
        ),
        param(
            'invalid signature_algorithm',
            signature_algorithm='pre-shared kitty cats',
            raises=exceptions.ParamValidationError,
            exception_message='invalid signature_algorithm argument provided',
        ),
    ])
    def test_sign_data(self, label, hash_input='hash this ish', hash_algorithm='sha2-256', signature_algorithm='pss',
                       raises=False, exception_message=''):
        """Sign a base64-encoded input with a freshly created ``ed25519`` transit key.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param hash_input: Raw text; it is base64-encoded before being signed.
        :param hash_algorithm: Hash algorithm name forwarded to ``sign_data``.
        :param signature_algorithm: Signature algorithm name forwarded to ``sign_data``.
        :param raises: Exception type ``sign_data`` is expected to raise, or a
            falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        encoded_input = utils.base64ify(hash_input)
        key_name = 'testkey'
        transit = self.client.secrets.transit

        # NOTE(review): 'pss' is presumably only meaningful for RSA keys; with an
        # ed25519 key the success case may not exercise it. Confirm against Vault docs.
        created = transit.create_key(
            name=key_name,
            key_type='ed25519',
            mount_point=self.TEST_MOUNT_POINT,
        )
        logging.debug('create_key_response: %s', created)

        sign_kwargs = dict(
            name=key_name,
            hash_input=encoded_input,
            hash_algorithm=hash_algorithm,
            signature_algorithm=signature_algorithm,
            mount_point=self.TEST_MOUNT_POINT,
        )
        if raises:
            with self.assertRaises(raises) as ctx:
                transit.sign_data(**sign_kwargs)
            self.assertIn(member=exception_message, container=str(ctx.exception))
            return

        signed = transit.sign_data(**sign_kwargs)
        logging.debug('sign_data_response: %s', signed)
        self.assertIn(member='signature', container=signed['data'])

    @parameterized.expand([
        param(
            'success',
        ),
        param(
            'invalid algorithm',
            hash_algorithm='meow2-256',
            raises=exceptions.ParamValidationError,
            exception_message='invalid hash_algorithm argument provided',
        ),
        param(
            'invalid signature_algorithm',
            signature_algorithm='pre-shared kitty cats',
            raises=exceptions.ParamValidationError,
            exception_message='invalid signature_algorithm argument provided',
        ),
    ])
    def test_verify_signed_data(self, label, hash_input='hash this ish', hash_algorithm='sha2-256', signature_algorithm='pss',
                                raises=False, exception_message=''):
        """Verify a signature produced by a freshly created ``ed25519`` transit key.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param hash_input: Raw text; it is base64-encoded before signing/verifying.
        :param hash_algorithm: Hash algorithm name forwarded to ``verify_signed_data``.
        :param signature_algorithm: Signature algorithm name forwarded to
            ``verify_signed_data``.
        :param raises: Exception type ``verify_signed_data`` is expected to raise,
            or a falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        encoded_input = utils.base64ify(hash_input)
        key_name = 'testkey'
        transit = self.client.secrets.transit
        mount_point = self.TEST_MOUNT_POINT

        created = transit.create_key(
            name=key_name,
            key_type='ed25519',
            mount_point=mount_point,
        )
        logging.debug('create_key_response: %s', created)

        # Signing always uses fixed, known-good algorithm names, so only the
        # verify call under test receives the parameterized (possibly invalid) ones.
        signed = transit.sign_data(
            name=key_name,
            hash_input=encoded_input,
            hash_algorithm='sha2-256',
            signature_algorithm='pss',
            mount_point=mount_point,
        )
        logging.debug('sign_data_response: %s', signed)

        verify_kwargs = dict(
            name=key_name,
            hash_input=encoded_input,
            signature=signed['data']['signature'],
            hash_algorithm=hash_algorithm,
            signature_algorithm=signature_algorithm,
            mount_point=mount_point,
        )
        if raises:
            with self.assertRaises(raises) as ctx:
                transit.verify_signed_data(**verify_kwargs)
            self.assertIn(member=exception_message, container=str(ctx.exception))
            return

        verification = transit.verify_signed_data(**verify_kwargs)
        logging.debug('verify_signed_data_response: %s', verification)
        self.assertTrue(expr=verification['data']['valid'])

    @parameterized.expand([
        param(
            'success',
        ),
        param(
            'allow_plaintext_backup false',
            allow_plaintext_backup=False,
            raises=exceptions.InternalServerError,
            exception_message='plaintext backup is disallowed on the policy',
        ),
    ])
    @skipIf(utils.skip_if_vault_version_lt('0.9.1'), "transit key export/restore added in Vault versions >=0.9.1")
    def test_backup_key(self, label, allow_plaintext_backup=True, raises=False, exception_message=''):
        """Back up a freshly created, exportable transit key.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param allow_plaintext_backup: Value written to the key configuration before
            the backup is requested.
        :param raises: Exception type ``backup_key`` is expected to raise, or a
            falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        transit = self.client.secrets.transit
        mount_point = self.TEST_MOUNT_POINT

        created = transit.create_key(name=key_name, mount_point=mount_point)
        logging.debug('create_key_response: %s', created)

        configured = transit.update_key_configuration(
            name=key_name,
            exportable=True,
            allow_plaintext_backup=allow_plaintext_backup,
            mount_point=mount_point,
        )
        logging.debug('update_key_configuration_response: %s', configured)

        if raises:
            with self.assertRaises(raises) as ctx:
                transit.backup_key(name=key_name, mount_point=mount_point)
            self.assertIn(member=exception_message, container=str(ctx.exception))
            return

        backup_result = transit.backup_key(name=key_name, mount_point=mount_point)
        logging.debug('backup_key_response: %s', backup_result)
        self.assertIn(member='backup', container=backup_result['data'])

    @parameterized.expand([
        param(
            'success',
        ),
        param(
            'success with force',
            force=True,
        ),
        param(
            'existing key without force',
            name=None,
            raises=exceptions.InternalServerError,
            exception_message='already exists',
        ),
    ])
    @skipIf(utils.skip_if_vault_version_lt('0.9.1'), "transit key export/restore added in Vault versions >=0.9.1")
    def test_restore_key(self, label, name='new_test_ky', force=False, raises=False, exception_message=''):
        """Back up a transit key and restore the backup, optionally under a new name.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param name: Target key name for the restore. The ``None`` case presumably
            restores under the backed-up key's own name ('testkey'), which already
            exists. The parameterized case expects an 'already exists' error.
        :param force: Value forwarded as ``force`` to ``restore_key``.
        :param raises: Exception type ``restore_key`` is expected to raise, or a
            falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        transit = self.client.secrets.transit
        mount_point = self.TEST_MOUNT_POINT

        created = transit.create_key(name=key_name, mount_point=mount_point)
        logging.debug('create_key_response: %s', created)

        configured = transit.update_key_configuration(
            name=key_name,
            exportable=True,
            allow_plaintext_backup=True,
            mount_point=mount_point,
        )
        logging.debug('update_key_configuration_response: %s', configured)

        backup_result = transit.backup_key(name=key_name, mount_point=mount_point)
        logging.debug('backup_key_response: %s', backup_result)

        restore_kwargs = dict(
            backup=backup_result['data']['backup'],
            name=name,
            force=force,
            mount_point=mount_point,
        )
        if raises:
            with self.assertRaises(raises) as ctx:
                transit.restore_key(**restore_kwargs)
            self.assertIn(member=exception_message, container=str(ctx.exception))
            return

        restored = transit.restore_key(**restore_kwargs)
        logging.debug('restore_key_response: %s', restored)
        self.assertEqual(first=restored.status_code, second=204)

    @parameterized.expand([
        param(
            'success',
        ),
    ])
    @skipIf(utils.skip_if_vault_version_lt('0.11.4'), "transit key trimming added in Vault versions >=0.11.4")
    def test_trim_key(self, label, min_version=2, raises=False, exception_message=''):
        """Trim old versions from a transit key that has been rotated repeatedly.

        The key is created, rotated ten times, and its minimum decryption and
        encryption versions are raised (to 3 and 9) before ``trim_key`` is called.

        :param label: Human-readable case name supplied by ``parameterized``.
        :param min_version: Value forwarded as ``min_version`` to ``trim_key``.
        :param raises: Exception type ``trim_key`` is expected to raise, or a
            falsy value when the call should succeed.
        :param exception_message: Substring expected in the raised exception's text.
        """
        key_name = 'testkey'
        transit = self.client.secrets.transit
        mount_point = self.TEST_MOUNT_POINT

        created = transit.create_key(name=key_name, mount_point=mount_point)
        logging.debug('create_key_response: %s', created)

        # Rotate repeatedly so the key has several versions available to trim.
        for _ in range(10):
            rotated = transit.rotate_key(name=key_name, mount_point=mount_point)
            logging.debug('rotate_key_response: %s', rotated)

        configured = transit.update_key_configuration(
            name=key_name,
            min_decryption_version=3,
            min_encryption_version=9,
            mount_point=mount_point,
        )
        logging.debug('update_key_configuration_response: %s', configured)

        # The key is read back only so its pre-trim state appears in the debug log.
        key_state = transit.read_key(name=key_name, mount_point=mount_point)
        logging.debug('read_key_response: %s', key_state)

        trim_kwargs = dict(
            name=key_name,
            min_version=min_version,
            mount_point=mount_point,
        )
        if raises:
            with self.assertRaises(raises) as ctx:
                transit.trim_key(**trim_kwargs)
            self.assertIn(member=exception_message, container=str(ctx.exception))
            return

        trimmed = transit.trim_key(**trim_kwargs)
        logging.debug('trim_key_response: %s', trimmed)
        self.assertEqual(first=trimmed.status_code, second=204)
예제 #7
0
파일: test_azure.py 프로젝트: yijxiang/hvac
import logging
from unittest import TestCase
from unittest import skipIf

import requests_mock
from parameterized import parameterized

from hvac.adapters import Request
from hvac.api.auth import Azure
from hvac.tests import utils


@skipIf(utils.skip_if_vault_version_lt('0.10.0'), "Azure auth method not available before Vault version 0.10.0")
class TestAzure(TestCase):
    TEST_MOUNT_POINT = 'azure-test'

    @parameterized.expand([
        ('success', dict(), None,),
        ('with subscription_id', dict(subscription_id='my_subscription_id'), None,),
        ('with resource_group_name', dict(resource_group_name='my_resource_group_name'), None,),
        ('with vm_name', dict(vm_name='my_vm_name'), None,),
        ('with vmss_name', dict(vmss_name='my_vmss_name'), None,),
        ('with vm_name and vmss_name', dict(vm_name='my_vm_name', vmss_name='my_vmss_name'), None,),
    ])
    @requests_mock.Mocker()
    def test_login(self, label, test_params, raises, requests_mocker):
        role_name = 'hvac'
        test_policies = [
            "default",
            "dev",
            "prod",