示例#1
0
class ScenarioTest(ReplayableTest):
    """Recordable scenario test for CLI commands.

    Any processor/patch argument left as ``None`` (or otherwise falsy) is
    replaced by a built-in default list. Random names produced while recording
    are registered with ``self.name_replacer`` against a deterministic moniker.
    """

    def __init__(self, method_name, config_file=None, recording_name=None,
                 recording_processors=None, replay_processors=None, recording_patches=None, replay_patches=None):
        # Created first: the default recording processors below include it.
        self.name_replacer = GeneralNameReplacer()

        if not recording_processors:
            recording_processors = [
                SubscriptionRecordingProcessor(MOCKED_SUBSCRIPTION_ID),
                OAuthRequestResponsesFilter(),
                LargeRequestBodyProcessor(),
                LargeResponseBodyProcessor(),
                DeploymentNameReplacer(),
                RequestUrlNormalizer(),
                self.name_replacer,
            ]
        if not replay_processors:
            replay_processors = [
                LargeResponseBodyReplacer(),
                DeploymentNameReplacer(),
                RequestUrlNormalizer(),
            ]
        if not recording_patches:
            recording_patches = [patch_main_exception_handler]
        if not replay_patches:
            replay_patches = [
                patch_main_exception_handler,
                patch_time_sleep_api,
                patch_long_run_operation_delay,
                patch_load_cached_subscriptions,
                patch_retrieve_token_for_user,
                patch_progress_controller,
            ]

        # The recording directory is derived from the source file of the
        # concrete (sub)class, not of this base class.
        defining_file = inspect.getfile(self.__class__)
        super(ScenarioTest, self).__init__(
            method_name,
            config_file=config_file,
            recording_processors=recording_processors,
            replay_processors=replay_processors,
            recording_patches=recording_patches,
            replay_patches=replay_patches,
            recording_dir=find_recording_dir(defining_file),
            recording_name=recording_name
        )

    def create_random_name(self, prefix, length):
        """Return a resource name for this test.

        Each call bumps ``test_resources_count`` and builds the moniker
        ``<prefix><count zero-padded to 6 digits>``. Outside of recording the
        moniker is returned as-is; while recording a real random name (from the
        module-level ``create_random_name``) is returned and registered so it
        can be replaced by the moniker.
        """
        self.test_resources_count += 1
        moniker = '{}{:06}'.format(prefix, self.test_resources_count)
        if not self.in_recording:
            return moniker

        generated = create_random_name(prefix, length)
        self.name_replacer.register_name_pair(generated, moniker)
        return generated

    @classmethod
    def cmd(cls, command, checks=None, expect_failure=False):
        """Execute a CLI ``command`` and apply ``checks`` to its result."""
        result = execute(command, expect_failure=expect_failure)
        return result.assert_with_checks(checks)

    def get_subscription_id(self):
        """Return the default subscription id (queried via the CLI when recording/live, mocked otherwise)."""
        if not (self.in_recording or self.is_live):
            return MOCKED_SUBSCRIPTION_ID
        query_result = self.cmd('account list --query "[?isDefault].id" -o tsv')
        return query_result.output.strip()
示例#2
0
class ScenarioTest(ReplayableTest):
    """Recordable scenario test for CLI commands.

    Any processor/patch/recording-dir argument left as ``None`` (or otherwise
    falsy) is replaced by a built-in default. Random names produced while
    recording are registered with ``self.name_replacer``.
    """

    def __init__(self,
                 method_name,
                 config_file=None,
                 recording_dir=None,
                 recording_name=None,
                 recording_processors=None,
                 replay_processors=None,
                 recording_patches=None,
                 replay_patches=None):
        # Created first: the default recording processors below include it.
        self.name_replacer = GeneralNameReplacer()

        if not recording_processors:
            recording_processors = [
                SubscriptionRecordingProcessor(MOCKED_SUBSCRIPTION_ID),
                OAuthRequestResponsesFilter(),
                LargeRequestBodyProcessor(),
                LargeResponseBodyProcessor(),
                DeploymentNameReplacer(),
                self.name_replacer,
            ]
        if not replay_processors:
            replay_processors = [
                LargeResponseBodyReplacer(),
                DeploymentNameReplacer(),
            ]
        if not recording_patches:
            recording_patches = [patch_main_exception_handler]
        if not replay_patches:
            replay_patches = [
                patch_main_exception_handler,
                patch_time_sleep_api,
                patch_long_run_operation_delay,
                patch_load_cached_subscriptions,
                patch_retrieve_token_for_user,
                patch_progress_controller,
            ]
        if not recording_dir:
            # Default: derived from the source file of the concrete (sub)class.
            recording_dir = find_recording_dir(inspect.getfile(self.__class__))

        super(ScenarioTest, self).__init__(
            method_name,
            config_file=config_file,
            recording_processors=recording_processors,
            replay_processors=replay_processors,
            recording_patches=recording_patches,
            replay_patches=replay_patches,
            recording_dir=recording_dir,
            recording_name=recording_name)

    def create_random_name(self, prefix, length):
        """Return a resource name for this test.

        Bumps ``test_resources_count`` and builds the moniker
        ``<prefix><count zero-padded to 6 digits>``. While recording, a real
        random name is returned and registered against that moniker;
        otherwise the moniker itself is returned.
        """
        self.test_resources_count += 1
        moniker = '{}{:06}'.format(prefix, self.test_resources_count)
        if not self.in_recording:
            return moniker

        generated = create_random_name(prefix, length)
        self.name_replacer.register_name_pair(generated, moniker)
        return generated

    @classmethod
    def cmd(cls, command, checks=None, expect_failure=False):
        """Execute a CLI ``command`` and apply ``checks`` to its result."""
        result = execute(command, expect_failure=expect_failure)
        return result.assert_with_checks(checks)
示例#3
0
class StorageExampleTest(ReplayableTest):
    """Record/replay test around the storage example's ``run_example``.

    When live, service-principal values from the environment and the storage
    account name are registered with the scrubber (both raw and
    ``quote_plus``-encoded) so they are replaced by dummy values.
    """

    def __init__(self, method_name, **kwargs):
        # NOTE(review): **kwargs is accepted but not forwarded to the base class.
        self.scrubber = GeneralNameReplacer()
        super(StorageExampleTest, self).__init__(
            method_name,
            config_file=TEST_CONFIG,
            recording_processors=[
                self.scrubber,
                SubscriptionRecordingProcessor(DUMMY_UUID),
                AccessTokenReplacer(),
            ],
            replay_patches=[patch_long_run_operation_delay],
        )
        if not self.is_live:
            return

        secret_placeholder_pairs = [
            (os.environ['AZURE_CLIENT_ID'], DUMMY_UUID),
            (os.environ['AZURE_CLIENT_SECRET'], DUMMY_SECRET),
            (os.environ['AZURE_TENANT_ID'], DUMMY_UUID),
            (STORAGE_ACCOUNT_NAME, DUMMY_STORAGE_NAME),
        ]
        for secret, placeholder in secret_placeholder_pairs:
            # Register the URL-encoded form too, since values may appear in query strings.
            self.scrubber.register_name_pair(secret, placeholder)
            self.scrubber.register_name_pair(quote_plus(secret), placeholder)

    @staticmethod
    def fake_credentials():
        """Return a ``(credentials, subscription_id)`` pair built from dummy values."""
        token_auth = BasicTokenAuthentication({'access_token': 'fake_token'})
        return token_auth, DUMMY_UUID

    def test_example(self):
        """Run the example live, or with credentials/account name patched for playback."""
        if self.is_live:
            run_example()
            return

        with mock.patch('example.get_credentials', StorageExampleTest.fake_credentials), \
                mock.patch('example.STORAGE_ACCOUNT_NAME', DUMMY_STORAGE_NAME):
            run_example()
class VirtualMachineExampleTest(ReplayableTest):
    """Record/replay test around the virtual machine example's ``run_example``.

    While recording, service-principal values from the environment and the
    storage account name are registered with the scrubber (both raw and
    ``quote_plus``-encoded) so they are replaced by dummy values.
    """

    def __init__(self, method_name, **kwargs):
        # NOTE(review): **kwargs is accepted but not forwarded to the base class.
        self.scrubber = GeneralNameReplacer()
        super(VirtualMachineExampleTest, self).__init__(
            method_name,
            config_file=TEST_CONFIG,
            recording_processors=[
                self.scrubber,
                SubscriptionRecordingProcessor(DUMMY_UUID),
                AccessTokenReplacer(),
            ],
            replay_patches=[patch_long_run_operation_delay],
        )
        if not self.in_recording:
            return

        secret_placeholder_pairs = [
            (os.environ['AZURE_CLIENT_ID'], DUMMY_UUID),
            (os.environ['AZURE_CLIENT_SECRET'], DUMMY_SECRET),
            (os.environ['AZURE_TENANT_ID'], DUMMY_UUID),
            (STORAGE_ACCOUNT_NAME, DUMMY_STORAGE_NAME),
        ]
        for secret, placeholder in secret_placeholder_pairs:
            # Register the URL-encoded form too, since values may appear in query strings.
            self.scrubber.register_name_pair(secret, placeholder)
            self.scrubber.register_name_pair(quote_plus(secret), placeholder)

    @staticmethod
    def fake_credentials():
        """Return a ``(credentials, subscription_id)`` pair built from dummy values."""
        token_auth = BasicTokenAuthentication({'access_token': 'fake_token'})
        return token_auth, DUMMY_UUID

    def test_example(self):
        """Run the example while recording, or with credentials/account name patched otherwise."""
        if self.in_recording:
            run_example()
            return

        with patch('example.get_credentials', VirtualMachineExampleTest.fake_credentials), \
             patch('example.STORAGE_ACCOUNT_NAME', DUMMY_STORAGE_NAME):
            run_example()
class AzureTestCase(ReplayableTest):
    """Replayable base test case for Azure SDK clients.

    Playback uses the ``fake_settings`` module; live runs use the optional
    ``mgmt_settings_real`` module imported relative to this package. The
    recording name defaults to the qualified test method name.
    """

    def __init__(self,
                 method_name,
                 config_file=None,
                 recording_dir=None,
                 recording_name=None,
                 recording_processors=None,
                 replay_processors=None,
                 recording_patches=None,
                 replay_patches=None):
        self.working_folder = os.path.dirname(__file__)
        self.qualified_test_name = get_qualified_method_name(self, method_name)
        self._fake_settings, self._real_settings = self._load_settings()
        self.scrubber = GeneralNameReplacer()
        # Default to the settings file next to this module; pass None to the
        # base class when that file does not exist.
        config_file = config_file or os.path.join(self.working_folder,
                                                  TEST_SETTING_FILENAME)
        if not os.path.exists(config_file):
            config_file = None
        super(AzureTestCase, self).__init__(
            method_name,
            config_file=config_file,
            recording_dir=recording_dir,
            recording_name=recording_name or self.qualified_test_name,
            recording_processors=recording_processors
            or self._get_recording_processors(),
            replay_processors=replay_processors
            or self._get_replay_processors(),
            recording_patches=recording_patches,
            replay_patches=replay_patches,
        )

    @property
    def settings(self):
        """Settings module for the current mode.

        :raises AzureTestError: when live and no ``mgmt_settings_real`` module
            could be imported.
        """
        if self.is_live:
            if self._real_settings:
                return self._real_settings
            else:
                raise AzureTestError(
                    'Need a mgmt_settings_real.py file to run tests live.')
        else:
            return self._fake_settings

    def _load_settings(self):
        """Return ``(fake_settings, real_settings_or_None)``."""
        try:
            from . import mgmt_settings_real as real_settings
            return fake_settings, real_settings
        except ImportError:
            return fake_settings, None

    def _get_recording_processors(self):
        """Default recording processors: scrubber, OAuth filter, URL normalizer."""
        return [
            self.scrubber,
            OAuthRequestResponsesFilter(),
            RequestUrlNormalizer()
        ]

    def _get_replay_processors(self):
        """Default replay processors: URL normalizer only."""
        return [RequestUrlNormalizer()]

    def is_playback(self):
        """Return True when the test is not running live."""
        return not self.is_live

    def get_settings_value(self, key):
        """Return the value of setting ``key``.

        The ``AZURE_<key>`` environment variable takes precedence; otherwise
        the value is read from :attr:`settings`.

        :raises ValueError: if ``AZURE_<key>`` and ``mgmt_settings_real`` both
            define ``key`` with different values.
        """
        key_value = os.environ.get("AZURE_" + key, None)

        if key_value and self._real_settings:
            # A key missing from mgmt_settings_real is not a conflict: the env
            # var simply wins (previously this raised AttributeError).
            real_value = getattr(self._real_settings, key, key_value)
            if real_value != key_value:
                raise ValueError(
                    "You have both AZURE_{key} env variable and mgmt_settings_real.py for {key} to different values"
                    .format(key=key))

        if not key_value:
            try:
                key_value = getattr(self.settings, key)
            except Exception:
                print("Could not get {}".format(key))
                raise
        return key_value

    def set_value_to_scrub(self, key, default_value):
        """Return the real value of ``key`` when live (registering it to be scrubbed as ``default_value``), else ``default_value``."""
        if self.is_live:
            value = self.get_settings_value(key)
            self.scrubber.register_name_pair(value, default_value)
            return value
        else:
            return default_value

    def setUp(self):
        # Every test uses a different resource group name calculated from its
        # qualified test name.
        #
        # When running all tests serially, this allows us to delete
        # the resource group in teardown without waiting for the delete to
        # complete. The next test in line will use a different resource group,
        # so it won't have any trouble creating its resource group even if the
        # previous test resource group hasn't finished deleting.
        #
        # When running tests individually, if you try to run the same test
        # multiple times in a row, it's possible that the delete in the previous
        # teardown hasn't completed yet (because we don't wait), and that
        # would make resource group creation fail.
        # To avoid that, we also delete the resource group in the
        # setup, and we wait for that delete to complete.
        # NOTE(review): none of the above is done in this method itself; it
        # presumably happens in resource preparers. This only calls the base.
        super(AzureTestCase, self).setUp()

    def tearDown(self):
        return super(AzureTestCase, self).tearDown()

    def create_basic_client(self, client_class, **kwargs):
        """Construct ``client_class`` with credentials for the current mode.

        First asserts that constructing it with ``credentials=None`` raises
        ``ValueError``. When live and ``AZURE_TENANT_ID``/``AZURE_CLIENT_ID``/
        ``AZURE_CLIENT_SECRET`` are all set, a ``ServicePrincipalCredentials``
        is used; otherwise ``self.settings.get_credentials()``. In playback the
        long-running-operation timeout is set to 0; HTTP logging is enabled.
        """
        # Whatever the client, if credentials is None, fail
        with self.assertRaises(ValueError):
            client_class(credentials=None, **kwargs)

        tenant_id = os.environ.get("AZURE_TENANT_ID", None)
        client_id = os.environ.get("AZURE_CLIENT_ID", None)
        secret = os.environ.get("AZURE_CLIENT_SECRET", None)

        if tenant_id and client_id and secret and self.is_live:
            from msrestazure.azure_active_directory import ServicePrincipalCredentials
            credentials = ServicePrincipalCredentials(tenant=tenant_id,
                                                      client_id=client_id,
                                                      secret=secret)
        else:
            credentials = self.settings.get_credentials()

        # Real client creation
        client = client_class(credentials=credentials, **kwargs)
        if self.is_playback():
            client.config.long_running_operation_timeout = 0
        client.config.enable_http_logger = True
        return client

    def create_random_name(self, name):
        """Return a name derived from ``name`` and the qualified test name."""
        return get_resource_name(name, self.qualified_test_name.encode())

    def get_resource_name(self, name):
        """Alias to create_random_name for back compatibility."""
        return self.create_random_name(name)

    def get_preparer_resource_name(self, prefix):
        """Random name generation for use by preparers.

        If prefix is a blank string, use the fully qualified test name instead.
        This is what legacy tests do for resource groups."""
        return self.get_resource_name(
            prefix or self.qualified_test_name.replace('.', '_'))
class AzureTestCase(ReplayableTest):
    """Replayable base test case for Azure SDK clients (Autorest v2 and v3).

    Playback uses the ``fake_settings`` module; live runs use the optional
    ``mgmt_settings_real`` module imported relative to this package, and/or
    ``AZURE_*`` environment variables (a ``.env`` file is loaded at init).
    """

    def __init__(self,
                 method_name,
                 config_file=None,
                 recording_dir=None,
                 recording_name=None,
                 recording_processors=None,
                 replay_processors=None,
                 recording_patches=None,
                 replay_patches=None,
                 **kwargs):
        self.working_folder = os.path.dirname(__file__)
        self.qualified_test_name = get_qualified_method_name(self, method_name)
        self._fake_settings, self._real_settings = self._load_settings()
        self.scrubber = GeneralNameReplacer()
        # Default to the settings file next to this module; pass None to the
        # base class when that file does not exist.
        config_file = config_file or os.path.join(self.working_folder,
                                                  TEST_SETTING_FILENAME)
        if not os.path.exists(config_file):
            config_file = None
        load_dotenv(find_dotenv())
        super(AzureTestCase,
              self).__init__(method_name,
                             config_file=config_file,
                             recording_dir=recording_dir,
                             recording_name=recording_name
                             or self.qualified_test_name,
                             recording_processors=recording_processors
                             or self._get_recording_processors(),
                             replay_processors=replay_processors
                             or self._get_replay_processors(),
                             recording_patches=recording_patches,
                             replay_patches=replay_patches,
                             **kwargs)

    @property
    def settings(self):
        """Settings module for the current mode.

        :raises AzureTestError: when live and no ``mgmt_settings_real`` module
            could be imported.
        """
        if self.is_live:
            if self._real_settings:
                return self._real_settings
            else:
                raise AzureTestError(
                    'Need a mgmt_settings_real.py file to run tests live.')
        else:
            return self._fake_settings

    def _load_settings(self):
        """Return ``(fake_settings, real_settings_or_None)``."""
        try:
            from . import mgmt_settings_real as real_settings
            return fake_settings, real_settings
        except ImportError:
            return fake_settings, None

    def _get_recording_processors(self):
        """Default recording processors: scrubber, auth-metadata filter, OAuth filter, URL normalizer."""
        return [
            self.scrubber,
            AuthenticationMetadataFilter(),
            OAuthRequestResponsesFilter(),
            RequestUrlNormalizer()
        ]

    def _get_replay_processors(self):
        """Default replay processors: URL normalizer only."""
        return [RequestUrlNormalizer()]

    def is_playback(self):
        """Return True when the test is not running live."""
        return not self.is_live

    def get_settings_value(self, key):
        """Return the value of setting ``key``.

        The ``AZURE_<key>`` environment variable takes precedence; otherwise
        the value is read from :attr:`settings`.

        :raises ValueError: if ``AZURE_<key>`` and ``mgmt_settings_real`` both
            define ``key`` with different values.
        """
        key_value = os.environ.get("AZURE_" + key, None)

        if key_value and self._real_settings:
            # A key missing from mgmt_settings_real is not a conflict: the env
            # var simply wins (previously this raised AttributeError).
            real_value = getattr(self._real_settings, key, key_value)
            if real_value != key_value:
                raise ValueError(
                    "You have both AZURE_{key} env variable and mgmt_settings_real.py for {key} to different values"
                    .format(key=key))

        if not key_value:
            try:
                key_value = getattr(self.settings, key)
            except Exception:
                print("Could not get {}".format(key))
                raise
        return key_value

    def set_value_to_scrub(self, key, default_value):
        """Return the real value of ``key`` when live (registering it to be scrubbed as ``default_value``), else ``default_value``."""
        if self.is_live:
            value = self.get_settings_value(key)
            self.scrubber.register_name_pair(value, default_value)
            return value
        else:
            return default_value

    def setUp(self):
        # Every test uses a different resource group name calculated from its
        # qualified test name.
        #
        # When running all tests serially, this allows us to delete
        # the resource group in teardown without waiting for the delete to
        # complete. The next test in line will use a different resource group,
        # so it won't have any trouble creating its resource group even if the
        # previous test resource group hasn't finished deleting.
        #
        # When running tests individually, if you try to run the same test
        # multiple times in a row, it's possible that the delete in the previous
        # teardown hasn't completed yet (because we don't wait), and that
        # would make resource group creation fail.
        # To avoid that, we also delete the resource group in the
        # setup, and we wait for that delete to complete.
        # NOTE(review): none of the above is done in this method itself; it
        # presumably happens in resource preparers. This only calls the base.
        super(AzureTestCase, self).setUp()

    def tearDown(self):
        return super(AzureTestCase, self).tearDown()

    def get_credential(self, client_class, **kwargs):
        """Return a credential object suitable for ``client_class``.

        Tenant/client id/secret come from ``AZURE_*`` env vars, falling back to
        ``mgmt_settings_real``. When live and all three are present, an
        azure-identity ``ClientSecretCredential`` (sync or ``is_async``) is
        returned for Autorest v3 clients, or a msrestazure
        ``ServicePrincipalCredentials`` for older clients. Otherwise the
        settings module / ``AsyncFakeCredential`` supplies the credential.

        :raises ValueError: for async Autorest v3 clients running live without
            the three ``AZURE_*`` values.
        """
        tenant_id = os.environ.get(
            "AZURE_TENANT_ID", getattr(self._real_settings, "TENANT_ID", None))
        client_id = os.environ.get(
            "AZURE_CLIENT_ID", getattr(self._real_settings, "CLIENT_ID", None))
        secret = os.environ.get(
            "AZURE_CLIENT_SECRET",
            getattr(self._real_settings, "CLIENT_SECRET", None))
        is_async = kwargs.pop("is_async", False)

        if tenant_id and client_id and secret and self.is_live:
            if _is_autorest_v3(client_class):
                # Create azure-identity class
                from azure.identity import ClientSecretCredential
                if is_async:
                    from azure.identity.aio import ClientSecretCredential
                return ClientSecretCredential(tenant_id=tenant_id,
                                              client_id=client_id,
                                              client_secret=secret)
            else:
                # Create msrestazure class
                from msrestazure.azure_active_directory import ServicePrincipalCredentials
                return ServicePrincipalCredentials(tenant=tenant_id,
                                                   client_id=client_id,
                                                   secret=secret)
        else:
            if _is_autorest_v3(client_class):
                if is_async:
                    if self.is_live:
                        raise ValueError(
                            "Async live doesn't support mgmt_settings_real.py, please set AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET"
                        )
                    return AsyncFakeCredential()
                else:
                    return self.settings.get_azure_core_credentials()
            else:
                return self.settings.get_credentials()

    def create_client_from_credential(self, client_class, credential,
                                      **kwargs):
        """Construct ``client_class`` from ``credential``.

        Autorest v3 clients get ``credential=`` and ``logging_enable`` defaulted
        to True; older clients get ``credentials=``. In playback, polling
        delays are zeroed where the client exposes them.
        """
        # Real client creation
        # TODO decide what is the final argument for that
        # if self.is_playback():
        #     kwargs.setdefault("polling_interval", 0)
        if _is_autorest_v3(client_class):
            kwargs.setdefault("logging_enable", True)
            client = client_class(credential=credential, **kwargs)
        else:
            client = client_class(credentials=credential, **kwargs)

        if self.is_playback():
            try:
                client._config.polling_interval = 0  # FIXME in azure-mgmt-core, make this a kwargs
            except AttributeError:
                pass

        if hasattr(client, "config"):  # Autorest v2
            if self.is_playback():
                client.config.long_running_operation_timeout = 0
            client.config.enable_http_logger = True
        return client

    def create_basic_client(self, client_class, **kwargs):
        """ DO NOT USE ME ANYMORE."""
        logger = logging.getLogger()
        # Implicit literal concatenation: a backslash continuation inside the
        # literal used to embed the next line's indentation in the message.
        logger.warning(
            "'create_basic_client' will be deprecated in the future. It is recommended that you use "
            "'get_credential' and 'create_client_from_credential' to create your client."
        )

        credentials = self.get_credential(client_class)
        return self.create_client_from_credential(client_class, credentials,
                                                  **kwargs)

    def create_random_name(self, name):
        """Return a name derived from ``name`` and the qualified test name."""
        return get_resource_name(name, self.qualified_test_name.encode())

    def get_resource_name(self, name):
        """Alias to create_random_name for back compatibility."""
        return self.create_random_name(name)

    def get_replayable_random_resource_name(self, name):
        """In a replay scenario, (is not live) gives the static moniker.  In the random scenario, gives generated name.

        When live, the generated name is registered with the scrubber against
        ``name`` so recordings contain the static moniker that playback uses.
        """
        if self.is_live:
            created_name = self.create_random_name(name)
            self.scrubber.register_name_pair(created_name, name)
            # Previously fell through and returned ``name``, contradicting the
            # documented contract and making the registration pointless.
            return created_name
        return name

    def get_preparer_resource_name(self, prefix):
        """Random name generation for use by preparers.

        If prefix is a blank string, use the fully qualified test name instead.
        This is what legacy tests do for resource groups."""
        return self.get_resource_name(
            prefix or self.qualified_test_name.replace('.', '_'))

    @staticmethod
    def await_prepared_test(test_fn):
        """Synchronous wrapper for async test methods. Used to avoid making changes
        upstream to AbstractPreparer, which only awaits async tests that use preparers.
        (Add @AzureTestCase.await_prepared_test decorator to async tests without preparers)

        # Note: this will only be needed so long as we maintain unittest.TestCase in our
        test-class inheritance chain.
        """

        if sys.version_info < (3, 5):
            raise ImportError(
                "Async wrapper is not needed for Python 2.7 code.")

        import asyncio

        @functools.wraps(test_fn)
        def run(test_class_instance, *args, **kwargs):
            # NOTE(review): positional *args are not forwarded to test_fn.
            trim_kwargs_from_test_function(test_fn, kwargs)
            loop = asyncio.get_event_loop()
            return loop.run_until_complete(
                test_fn(test_class_instance, **kwargs))

        return run
示例#7
0
class ScenarioTest(ReplayableTest, CheckerMixin, unittest.TestCase):
    """Recordable CLI scenario test.

    Caller-supplied processors/patches are merged after the built-in defaults
    (duplicates dropped, order preserved). Random names and GUIDs created
    while recording are registered with ``self.name_replacer`` against
    deterministic monikers.
    """

    def __init__(self, method_name, config_file=None, recording_name=None,
                 recording_processors=None, replay_processors=None, recording_patches=None, replay_patches=None):
        self.cli_ctx = get_dummy_cli()
        self.name_replacer = GeneralNameReplacer()
        # presumably consumed by _apply_kwargs (from CheckerMixin) in cmd()
        self.kwargs = {}
        self.test_guid_count = 0
        # Stateful processors whose reset() is called in tearDown.
        self._processors_to_reset = [StorageAccountKeyReplacer()]
        default_recording_processors = [
            SubscriptionRecordingProcessor(MOCKED_SUBSCRIPTION_ID),
            OAuthRequestResponsesFilter(),
            LargeRequestBodyProcessor(),
            LargeResponseBodyProcessor(),
            DeploymentNameReplacer(),
            RequestUrlNormalizer(),
            self.name_replacer
        ] + self._processors_to_reset

        default_replay_processors = [
            LargeResponseBodyReplacer(),
            DeploymentNameReplacer(),
            RequestUrlNormalizer(),
        ]

        default_recording_patches = [patch_main_exception_handler]

        default_replay_patches = [
            patch_main_exception_handler,
            patch_time_sleep_api,
            patch_long_run_operation_delay,
            patch_load_cached_subscriptions,
            patch_retrieve_token_for_user,
            patch_progress_controller,
        ]

        def _merge_lists(base, patches):
            """Return ``base`` followed by items of ``patches`` not already present.

            ``patches`` may be None, a single item, or a list. Order is kept so
            processors/patches apply in a deterministic sequence; the previous
            set-based union made the order depend on object hashes and failed
            on unhashable items.
            """
            merged = list(base)
            if patches and not isinstance(patches, list):
                patches = [patches]
            for item in patches or []:
                if item not in merged:
                    merged.append(item)
            return merged

        super(ScenarioTest, self).__init__(
            method_name,
            config_file=config_file,
            recording_processors=_merge_lists(default_recording_processors, recording_processors),
            replay_processors=_merge_lists(default_replay_processors, replay_processors),
            recording_patches=_merge_lists(default_recording_patches, recording_patches),
            replay_patches=_merge_lists(default_replay_patches, replay_patches),
            recording_dir=find_recording_dir(inspect.getfile(self.__class__)),
            recording_name=recording_name
        )

    def tearDown(self):
        """Reset stateful processors, then run the base tearDown."""
        for processor in self._processors_to_reset:
            processor.reset()
        super(ScenarioTest, self).tearDown()

    def create_random_name(self, prefix, length):
        """Return a resource name: a real random name while recording, else the moniker ``<prefix><6-digit count>``."""
        self.test_resources_count += 1
        moniker = '{}{:06}'.format(prefix, self.test_resources_count)

        if self.in_recording:
            name = create_random_name(prefix, length)
            self.name_replacer.register_name_pair(name, moniker)
            return name

        return moniker

    # Use this helper to make playback work when guids are created and used in request urls, e.g. role assignment or AAD
    # service principals. For usages, in test code, patch the "guid-gen" routine to this one, e.g.
    # with mock.patch('azure.cli.command_modules.role.custom._gen_guid', side_effect=self.create_guid)
    def create_guid(self):
        """Return a ``uuid.UUID``: random while recording, otherwise a moniker ending in the 4-hex-digit call count."""
        import uuid
        self.test_guid_count += 1
        moniker = '88888888-0000-0000-0000-00000000' + ("%0.4X" % self.test_guid_count)

        if self.in_recording:
            name = uuid.uuid4()
            self.name_replacer.register_name_pair(str(name), moniker)
            return name

        return uuid.UUID(moniker)

    def cmd(self, command, checks=None, expect_failure=False):
        """Substitute ``self.kwargs`` into ``command``, execute it and apply ``checks``."""
        command = self._apply_kwargs(command)
        return execute(self.cli_ctx, command, expect_failure=expect_failure).assert_with_checks(checks)

    def get_subscription_id(self):
        """Return the default subscription id (queried via the CLI when recording/live, mocked otherwise)."""
        if self.in_recording or self.is_live:
            subscription_id = self.cmd('account list --query "[?isDefault].id" -o tsv').output.strip()
        else:
            subscription_id = MOCKED_SUBSCRIPTION_ID
        return subscription_id
class AzureMgmtTestCase(ReplayableTest):
    """Replayable base test case for Azure management-plane clients.

    Playback uses the ``fake_settings`` module and live runs use the optional
    ``mgmt_settings_real`` module. Well-known identifiers present in both are
    registered with the scrubber before each test.
    """

    def __init__(self, method_name, config_file=None,
                 recording_dir=None, recording_name=None,
                 recording_processors=None, replay_processors=None,
                 recording_patches=None, replay_patches=None):
        self.working_folder = os.path.dirname(__file__)
        self.qualified_test_name = get_qualified_method_name(self, method_name)
        self._fake_settings, self._real_settings = self._load_settings()
        self.region = 'westus'
        self.scrubber = GeneralNameReplacer()

        if not config_file:
            config_file = os.path.join(self.working_folder, TEST_SETTING_FILENAME)
        if not os.path.exists(config_file):
            # No settings file available: hand None to the base class.
            config_file = None

        super(AzureMgmtTestCase, self).__init__(
            method_name,
            config_file=config_file,
            recording_dir=recording_dir,
            recording_name=recording_name or self.qualified_test_name,
            recording_processors=recording_processors or self._get_recording_processors(),
            replay_processors=replay_processors or self._get_replay_processors(),
            recording_patches=recording_patches,
            replay_patches=replay_patches,
        )

    @property
    def settings(self):
        """Fake settings in playback; real settings when live (error if missing)."""
        if not self.is_live:
            return self._fake_settings
        if not self._real_settings:
            raise AzureTestError('Need a mgmt_settings_real.py file to run tests live.')
        return self._real_settings

    def _load_settings(self):
        """Return ``(fake_settings, real_settings_or_None)``."""
        try:
            from . import mgmt_settings_real as real_settings
        except ImportError:
            return fake_settings, None
        return fake_settings, real_settings

    def _get_recording_processors(self):
        """Default recording processors.

        DeploymentNameReplacer is intentionally left out so tests keep full
        control over deployment names.
        """
        oauth_filter = OAuthRequestResponsesFilter()
        url_normalizer = RequestUrlNormalizer()
        return [self.scrubber, oauth_filter, url_normalizer]

    def _get_replay_processors(self):
        """Default replay processors: URL normalizer only."""
        return [RequestUrlNormalizer()]

    def is_playback(self):
        """Return True when the test is not running live."""
        return not self.is_live

    def _setup_scrubber(self):
        """Register real->fake pairs for identifiers defined in both settings modules."""
        for setting_name in ('SUBSCRIPTION_ID', 'AD_DOMAIN', 'TENANT_ID', 'CLIENT_OID', 'ADLA_JOB_ID'):
            if not (hasattr(self.settings, setting_name) and hasattr(self._fake_settings, setting_name)):
                continue
            self.scrubber.register_name_pair(getattr(self.settings, setting_name),
                                             getattr(self._fake_settings, setting_name))

    def setUp(self):
        """Register scrubbing pairs, then run the base setUp.

        Each test derives its resource-group name from its qualified test name
        so serial runs never collide on a group still being deleted.
        NOTE(review): resource-group creation/deletion is not performed here;
        presumably it is handled by resource preparers.
        """
        self._setup_scrubber()
        super(AzureMgmtTestCase, self).setUp()

    def tearDown(self):
        return super(AzureMgmtTestCase, self).tearDown()

    def create_basic_client(self, client_class, **kwargs):
        """Construct ``client_class`` with the settings' credentials.

        Asserts that ``credentials=None`` is rejected with ``ValueError``
        first. In playback the long-running-operation timeout is zeroed;
        HTTP logging is always enabled.
        """
        with self.assertRaises(ValueError):
            client_class(credentials=None, **kwargs)

        credentials = self.settings.get_credentials()
        client = client_class(credentials=credentials, **kwargs)
        if self.is_playback():
            client.config.long_running_operation_timeout = 0
        client.config.enable_http_logger = True
        return client

    def create_mgmt_client(self, client_class, **kwargs):
        """Construct a management client bound to the settings' subscription id.

        Asserts that ``subscription_id=None`` is rejected with ``ValueError`` first.
        """
        with self.assertRaises(ValueError):
            self.create_basic_client(client_class, subscription_id=None, **kwargs)

        subscription_id = self.settings.SUBSCRIPTION_ID
        return self.create_basic_client(client_class, subscription_id=subscription_id, **kwargs)

    def create_random_name(self, name):
        """Return a name derived from ``name`` and the qualified test name."""
        seed = self.qualified_test_name.encode()
        return get_resource_name(name, seed)

    def get_resource_name(self, name):
        """Alias to create_random_name for back compatibility."""
        return self.create_random_name(name)

    def get_preparer_resource_name(self, prefix):
        """Random name generation for use by preparers.

        A blank prefix falls back to the fully qualified test name with dots
        replaced by underscores, as legacy tests do for resource groups.
        """
        effective_prefix = prefix or self.qualified_test_name.replace('.', '_')
        return self.get_resource_name(effective_prefix)
示例#9
0
class ScenarioTest(ReplayableTest, CheckerMixin, unittest.TestCase):
    """Record/replay scenario test for Azure CLI commands.

    Commands are run through :meth:`cmd`. Names and GUIDs generated while
    recording are registered with ``self.name_replacer`` against stable
    monikers, which are returned directly during playback.
    """

    def __init__(self, method_name, config_file=None, recording_name=None,
                 recording_processors=None, replay_processors=None, recording_patches=None, replay_patches=None):
        self.cli_ctx = get_dummy_cli()
        self.name_replacer = GeneralNameReplacer()
        self.kwargs = {}
        self.test_guid_count = 0
        # Stateful processors that tearDown resets after every test.
        self._processors_to_reset = [StorageAccountKeyReplacer()]
        default_recording_processors = [
            SubscriptionRecordingProcessor(MOCKED_SUBSCRIPTION_ID),
            OAuthRequestResponsesFilter(),
            LargeRequestBodyProcessor(),
            LargeResponseBodyProcessor(),
            DeploymentNameReplacer(),
            RequestUrlNormalizer(),
            self.name_replacer
        ] + self._processors_to_reset

        default_replay_processors = [
            LargeResponseBodyReplacer(),
            DeploymentNameReplacer(),
            RequestUrlNormalizer(),
        ]

        default_recording_patches = [patch_main_exception_handler]

        default_replay_patches = [
            patch_main_exception_handler,
            patch_time_sleep_api,
            patch_long_run_operation_delay,
            patch_load_cached_subscriptions,
            patch_retrieve_token_for_user,
            patch_progress_controller,
        ]

        def _merge_lists(base, extras):
            """Return *base* followed by the items of *extras* not already present.

            Order is preserved on purpose. Processors and patches are applied
            in list order, and a set-based union would make that order
            arbitrary from run to run. A single non-list/tuple item is treated
            as a one-element list.
            """
            merged = list(base)
            if not extras:
                return merged
            if not isinstance(extras, (list, tuple)):
                extras = [extras]
            for item in extras:
                if item not in merged:
                    merged.append(item)
            return merged

        super(ScenarioTest, self).__init__(
            method_name,
            config_file=config_file,
            recording_processors=_merge_lists(default_recording_processors, recording_processors),
            replay_processors=_merge_lists(default_replay_processors, replay_processors),
            recording_patches=_merge_lists(default_recording_patches, recording_patches),
            replay_patches=_merge_lists(default_replay_patches, replay_patches),
            recording_dir=find_recording_dir(inspect.getfile(self.__class__)),
            recording_name=recording_name
        )

    def tearDown(self):
        """Reset stateful processors, then run the base-class teardown."""
        for processor in self._processors_to_reset:
            processor.reset()
        super(ScenarioTest, self).tearDown()

    def create_random_name(self, prefix, length):
        """Return a name for a test resource.

        While recording, a random name is generated and registered against a
        sequential moniker (``prefix`` + 6-digit counter). Otherwise the
        moniker itself is returned.
        """
        self.test_resources_count += 1
        moniker = '{}{:06}'.format(prefix, self.test_resources_count)

        if self.in_recording:
            name = create_random_name(prefix, length)
            self.name_replacer.register_name_pair(name, moniker)
            return name

        return moniker

    # Use this helper to make playback work when guids are created and used in request urls, e.g. role assignment or AAD
    # service principals. For usages, in test code, patch the "guid-gen" routine to this one, e.g.
    # with mock.patch('azure.cli.command_modules.role.custom._gen_guid', side_effect=self.create_guid)
    def create_guid(self):
        """Return a ``uuid.UUID``.

        While recording, this is a fresh random ``uuid4`` whose string form is
        registered against a counter-based moniker. Otherwise the moniker is
        parsed and returned.
        """
        import uuid
        self.test_guid_count += 1
        moniker = '88888888-0000-0000-0000-00000000' + ("%0.4X" % self.test_guid_count)

        if self.in_recording:
            name = uuid.uuid4()
            self.name_replacer.register_name_pair(str(name), moniker)
            return name

        return uuid.UUID(moniker)

    def cmd(self, command, checks=None, expect_failure=False):
        """Execute a CLI *command* and apply *checks* to the result.

        ``self._apply_kwargs`` (presumably from ``CheckerMixin``) substitutes
        ``self.kwargs`` into the command first.
        """
        command = self._apply_kwargs(command)
        return execute(self.cli_ctx, command, expect_failure=expect_failure).assert_with_checks(checks)

    def get_subscription_id(self):
        """Return the default account's subscription id, or the mocked id in playback."""
        if self.in_recording or self.is_live:
            subscription_id = self.cmd('account list --query "[?isDefault].id" -o tsv').output.strip()
        else:
            subscription_id = MOCKED_SUBSCRIPTION_ID
        return subscription_id
示例#10
0
文件: base.py 项目: trgrie/azure-cli
class ScenarioTest(ReplayableTest, CheckerMixin, unittest.TestCase):
    """Record/replay scenario test for Azure CLI commands.

    Commands are run through :meth:`cmd`. Names generated while recording are
    registered with ``self.name_replacer`` against sequential monikers, which
    are returned directly during playback.
    """

    def __init__(self,
                 method_name,
                 config_file=None,
                 recording_name=None,
                 recording_processors=None,
                 replay_processors=None,
                 recording_patches=None,
                 replay_patches=None):
        from azure.cli.testsdk import TestCli
        self.cli_ctx = TestCli()
        self.name_replacer = GeneralNameReplacer()
        self.kwargs = {}

        default_recording_processors = [
            SubscriptionRecordingProcessor(MOCKED_SUBSCRIPTION_ID),
            OAuthRequestResponsesFilter(),
            LargeRequestBodyProcessor(),
            LargeResponseBodyProcessor(),
            DeploymentNameReplacer(),
            RequestUrlNormalizer(),
            self.name_replacer,
        ]

        default_replay_processors = [
            LargeResponseBodyReplacer(),
            DeploymentNameReplacer(),
            RequestUrlNormalizer(),
        ]

        default_recording_patches = [patch_main_exception_handler]

        default_replay_patches = [
            patch_main_exception_handler,
            patch_time_sleep_api,
            patch_long_run_operation_delay,
            patch_load_cached_subscriptions,
            patch_retrieve_token_for_user,
            patch_progress_controller,
        ]

        def _merge_lists(base, extras):
            """Return *base* followed by the items of *extras* not already present.

            Order is preserved on purpose. Processors and patches are applied
            in list order, and a set-based union would make that order
            arbitrary from run to run. A single non-list/tuple item is treated
            as a one-element list.
            """
            merged = list(base)
            if not extras:
                return merged
            if not isinstance(extras, (list, tuple)):
                extras = [extras]
            for item in extras:
                if item not in merged:
                    merged.append(item)
            return merged

        super(ScenarioTest, self).__init__(
            method_name,
            config_file=config_file,
            recording_processors=_merge_lists(default_recording_processors,
                                              recording_processors),
            replay_processors=_merge_lists(default_replay_processors,
                                           replay_processors),
            recording_patches=_merge_lists(default_recording_patches,
                                           recording_patches),
            replay_patches=_merge_lists(default_replay_patches,
                                        replay_patches),
            recording_dir=find_recording_dir(self.cli_ctx,
                                             inspect.getfile(self.__class__)),
            recording_name=recording_name)

    def create_random_name(self, prefix, length):
        """Return a name for a test resource.

        While recording, a random name is generated and registered against a
        sequential moniker (``prefix`` + 6-digit counter). Otherwise the
        moniker itself is returned.
        """
        self.test_resources_count += 1
        moniker = '{}{:06}'.format(prefix, self.test_resources_count)

        if self.in_recording:
            name = create_random_name(prefix, length)
            self.name_replacer.register_name_pair(name, moniker)
            return name

        return moniker

    def cmd(self, command, checks=None, expect_failure=False):
        """Execute a CLI *command* and apply *checks* to the result.

        ``{placeholders}`` in *command* are filled from ``self.kwargs``.
        """
        try:
            command = command.format(**self.kwargs)
        except KeyError:
            # Deliberate best effort: if the command references a placeholder
            # missing from self.kwargs, it is sent unformatted.
            pass
        return execute(
            self.cli_ctx, command,
            expect_failure=expect_failure).assert_with_checks(checks)

    def get_subscription_id(self):
        """Return the default account's subscription id, or the mocked id in playback."""
        if self.in_recording or self.is_live:
            result = self.cmd('account list --query "[?isDefault].id" -o tsv')
            subscription_id = result.output.strip()
        else:
            subscription_id = MOCKED_SUBSCRIPTION_ID
        return subscription_id
示例#11
0
class AzureMgmtTestCase(ReplayableTest):
    """Replayable base test case for Azure management SDK tests.

    Loads fake settings (always) and real settings (when
    ``mgmt_settings_real`` is importable). Registers real -> fake pairs for
    sensitive settings values with ``self.scrubber``, which is the first
    default recording processor.
    """

    def __init__(self,
                 method_name,
                 config_file=None,
                 recording_dir=None,
                 recording_name=None,
                 recording_processors=None,
                 replay_processors=None,
                 recording_patches=None,
                 replay_patches=None):
        self.working_folder = os.path.dirname(__file__)
        self.qualified_test_name = get_qualified_method_name(self, method_name)
        self._fake_settings, self._real_settings = self._load_settings()
        self.region = 'westus'
        # Must exist before _get_recording_processors() is called below.
        self.scrubber = GeneralNameReplacer()
        config_file = config_file or os.path.join(self.working_folder,
                                                  TEST_SETTING_FILENAME)
        if not os.path.exists(config_file):
            config_file = None
        super(AzureMgmtTestCase, self).__init__(
            method_name,
            config_file=config_file,
            recording_dir=recording_dir,
            recording_name=recording_name or self.qualified_test_name,
            recording_processors=recording_processors
            or self._get_recording_processors(),
            replay_processors=replay_processors
            or self._get_replay_processors(),
            recording_patches=recording_patches,
            replay_patches=replay_patches,
        )

    @property
    def settings(self):
        """Settings for the current mode: real settings when live, fake otherwise.

        :raises AzureTestError: when running live without real settings.
        """
        if not self.is_live:
            return self._fake_settings
        if self._real_settings:
            return self._real_settings
        raise AzureTestError(
            'Need a mgmt_settings_real.py file to run tests live.')

    def _load_settings(self):
        """Return ``(fake_settings, real_settings)``; real is None if not importable."""
        try:
            from . import mgmt_settings_real as real_settings
            return fake_settings, real_settings
        except ImportError:
            return fake_settings, None

    def _get_recording_processors(self):
        """Default processors applied while recording."""
        return [
            self.scrubber,
            OAuthRequestResponsesFilter(),
            DeploymentNameReplacer(),
            RequestUrlNormalizer()
        ]

    def _get_replay_processors(self):
        """Default processors applied during playback."""
        return [RequestUrlNormalizer()]

    def is_playback(self):
        """True when the test is not running live."""
        return not self.is_live

    def _setup_scrubber(self):
        """Register real -> fake pairs for sensitive settings with the scrubber."""
        constants_to_scrub = [
            'SUBSCRIPTION_ID', 'AD_DOMAIN', 'TENANT_ID', 'CLIENT_OID'
        ]
        for key in constants_to_scrub:
            if hasattr(self.settings, key) and hasattr(self._fake_settings,
                                                       key):
                self.scrubber.register_name_pair(
                    getattr(self.settings, key),
                    getattr(self._fake_settings, key))

    def setUp(self):
        # Register the scrubber's real -> fake replacements, then run the
        # base-class setup.
        #
        # NOTE(review): this method does not itself delete or wait on any
        # resource group from a previous run of the same test. Presumably that
        # cleanup lives in a resource-group preparer; confirm before relying
        # on it.
        self._setup_scrubber()
        super(AzureMgmtTestCase, self).setUp()

    def tearDown(self):
        return super(AzureMgmtTestCase, self).tearDown()

    def create_basic_client(self, client_class, **kwargs):
        """Build a ``client_class`` instance using the test credentials.

        Before building the real client, assert that constructing one without
        credentials raises ``ValueError``. In playback, the client's
        ``long_running_operation_timeout`` is set to 0.
        """
        # Whatever the client, if credentials is None, construction must fail.
        # The call is expected to raise, so no result is kept.
        with self.assertRaises(ValueError):
            client_class(credentials=None, **kwargs)

        # Real client creation
        client = client_class(credentials=self.settings.get_credentials(),
                              **kwargs)
        if self.is_playback():
            client.config.long_running_operation_timeout = 0
        return client

    def create_mgmt_client(self, client_class, **kwargs):
        """Build a client bound to ``settings.SUBSCRIPTION_ID``.

        First asserts that ``subscription_id=None`` raises ``ValueError``.
        """
        with self.assertRaises(ValueError):
            self.create_basic_client(client_class,
                                     subscription_id=None,
                                     **kwargs)

        return self.create_basic_client(
            client_class,
            subscription_id=self.settings.SUBSCRIPTION_ID,
            **kwargs)

    def create_random_name(self, name):
        """Return a resource name for *name*, seeded by the qualified test name."""
        # ``get_resource_name`` here is the module-level helper, not the method.
        return get_resource_name(name, self.qualified_test_name.encode())

    def get_resource_name(self, name):
        """Alias to create_random_name for back compatibility."""
        return self.create_random_name(name)

    def get_preparer_resource_name(self, prefix):
        """Random name generation for use by preparers.

        If prefix is a blank string, use the fully qualified test name instead.
        This is what legacy tests do for resource groups."""
        return self.get_resource_name(
            prefix or self.qualified_test_name.replace('.', '_'))
示例#12
0
class AzureMgmtTestCase(AzureTestCase):
    def __init__(self,
                 method_name,
                 config_file=None,
                 recording_dir=None,
                 recording_name=None,
                 recording_processors=None,
                 replay_processors=None,
                 recording_patches=None,
                 replay_patches=None):
        """Prepare settings, scrubber and config, then initialize the base class.

        When *config_file* is falsy, it defaults to ``TEST_SETTING_FILENAME``
        next to this module. It becomes ``None`` if that path does not exist.
        Falsy processor lists fall back to the class defaults, and the
        recording name defaults to the qualified test name.
        """
        self.working_folder = os.path.dirname(__file__)
        self.qualified_test_name = get_qualified_method_name(self, method_name)
        self._fake_settings, self._real_settings = self._load_settings()
        self.region = 'westus'
        self.scrubber = GeneralNameReplacer()

        if not config_file:
            config_file = os.path.join(self.working_folder, TEST_SETTING_FILENAME)
        if not os.path.exists(config_file):
            config_file = None

        effective_name = recording_name or self.qualified_test_name
        if not recording_processors:
            recording_processors = self._get_recording_processors()
        if not replay_processors:
            replay_processors = self._get_replay_processors()

        super(AzureMgmtTestCase, self).__init__(
            method_name,
            config_file=config_file,
            recording_dir=recording_dir,
            recording_name=effective_name,
            recording_processors=recording_processors,
            replay_processors=replay_processors,
            recording_patches=recording_patches,
            replay_patches=replay_patches,
        )

    def _setup_scrubber(self):
        """Register real -> fake name pairs for sensitive settings keys.

        A key is registered with ``self.scrubber`` only when both the active
        settings and the fake settings define it.
        """
        sensitive_keys = (
            'SUBSCRIPTION_ID',
            'AD_DOMAIN',
            'TENANT_ID',
            'CLIENT_OID',
            'ADLA_JOB_ID',
        )
        for key in sensitive_keys:
            if not (hasattr(self.settings, key) and hasattr(self._fake_settings, key)):
                continue
            real_value = getattr(self.settings, key)
            fake_value = getattr(self._fake_settings, key)
            self.scrubber.register_name_pair(real_value, fake_value)

    def setUp(self):
        # Register real -> fake name pairs for sensitive settings values with
        # self.scrubber (presumably used to scrub recordings), then run the
        # base-class setup.
        #
        # NOTE(review): this method does not itself delete or wait on any
        # resource group from a previous run of the same test. Presumably that
        # cleanup lives in a resource-group preparer or the base class; confirm
        # before relying on it.
        self._setup_scrubber()
        super(AzureMgmtTestCase, self).setUp()

    def tearDown(self):
        """Run the parent class teardown and hand back whatever it returns."""
        parent = super(AzureMgmtTestCase, self)
        return parent.tearDown()

    def create_mgmt_client(self, client_class, **kwargs):
        """Build a management client bound to the configured subscription.

        First checks that ``subscription_id=None`` is rejected with
        ``ValueError``. Then builds the real client through
        ``create_basic_client`` with ``settings.SUBSCRIPTION_ID``.
        """
        # Sanity check: a missing subscription id must be refused.
        with self.assertRaises(ValueError):
            self.create_basic_client(client_class, subscription_id=None, **kwargs)

        subscription = self.settings.SUBSCRIPTION_ID
        return self.create_basic_client(client_class, subscription_id=subscription, **kwargs)