def _decrypt_and_deserialize_entity(entity, property_resolver,
                                    require_encryption, key_encryption_key,
                                    key_resolver):
    '''
    Deserializes a JSON-decoded entity and, when client-side encryption
    metadata is available, decrypts its encrypted properties.

    :param dict entity:
        The JSON-decoded entity, passed on to _convert_json_to_entity.
    :param function property_resolver:
        Optional user resolver, forwarded to _convert_json_to_entity.
    :param bool require_encryption:
        Forwarded to _validate_decryption_required and
        _extract_encryption_metadata.
    :param object key_encryption_key:
        The user-provided key-encryption-key, or None.
    :param function key_resolver:
        The user-provided key resolver, or None.
    :return: The deserialized (and, if applicable, decrypted) entity.
    :raises AzureException:
        with _ERROR_DECRYPTION_FAILURE if validation, metadata extraction or
        property decryption fails.
    '''
    try:
        _validate_decryption_required(require_encryption, key_encryption_key,
                                      key_resolver)
        entity_iv, encrypted_properties, content_encryption_key, isJavaV1 = None, None, None, False
        # Encryption metadata is only looked up when a key or resolver was supplied.
        if (key_encryption_key is not None) or (key_resolver is not None):
            entity_iv, encrypted_properties, content_encryption_key, isJavaV1 = \
                _extract_encryption_metadata(entity, require_encryption, key_encryption_key, key_resolver)
    except Exception:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # are not disguised as decryption failures.
        raise AzureException(_ERROR_DECRYPTION_FAILURE)

    entity = _convert_json_to_entity(entity, property_resolver,
                                     encrypted_properties)

    # Decrypt only when every piece of metadata needed for decryption was recovered.
    if entity_iv is not None and encrypted_properties is not None and \
            content_encryption_key is not None:
        try:
            entity = _decrypt_entity(entity, encrypted_properties,
                                     content_encryption_key, entity_iv,
                                     isJavaV1)
        except Exception:
            raise AzureException(_ERROR_DECRYPTION_FAILURE)

    return entity
Beispiel #2
0
def _parse_blob(response, name, snapshot, validate_content=False, require_encryption=False,
                key_encryption_key=None, key_resolver_function=None, start_offset=None, end_offset=None):
    if response is None:
        return None

    metadata = _parse_metadata(response)
    props = _parse_properties(response, BlobProperties)

    # For range gets, only look at 'x-ms-blob-content-md5' for overall MD5
    content_settings = getattr(props, 'content_settings')
    if 'content-range' in response.headers:
        if 'x-ms-blob-content-md5' in response.headers:
            setattr(content_settings, 'content_md5', _to_str(response.headers['x-ms-blob-content-md5']))
        else:
            delattr(content_settings, 'content_md5')

    if validate_content:
        computed_md5 = _get_content_md5(response.body)
        _validate_content_match(response.headers['content-md5'], computed_md5)

    if key_encryption_key is not None or key_resolver_function is not None:
        try:
            response.body = _decrypt_blob(require_encryption, key_encryption_key, key_resolver_function,
                                          response, start_offset, end_offset)
        except:
            raise AzureException(_ERROR_DECRYPTION_FAILURE)

    return Blob(name, snapshot, response.body, props, metadata)
Beispiel #3
0
def _decrypt_queue_message(message, require_encryption, key_encryption_key, resolver):
    '''
    Returns the decrypted message contents from an EncryptedQueueMessage.
    If no encryption metadata is present, will return the unaltered message.
    :param str message:
        The JSON formatted QueueEncryptedMessage contents with all associated metadata.
    :param bool require_encryption:
        If set, will enforce that the retrieved messages are encrypted and decrypt them.
    :param object key_encryption_key:
        The user-provided key-encryption-key. Must implement the following methods:
        unwrap_key(key, algorithm)--returns the unwrapped form of the specified symmetric key using the string-specified algorithm.
        get_kid()--returns a string key id for this key-encryption-key.
    :param function resolver(kid):
        The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above.
    :return: The plain text message from the queue message.
    :rtype: str
    '''

    try:
        message = loads(message)

        encryption_data = _dict_to_encryption_data(message['EncryptionData'])
        decoded_data = _decode_base64_to_bytes(message['EncryptedMessageContents'])
    except (KeyError, ValueError):
        # Message was not json formatted and so was not encrypted
        # or the user provided a json formatted message.
        if require_encryption:
            raise ValueError(_ERROR_MESSAGE_NOT_ENCRYPTED)
        else:
            return message
    try:
        return _decrypt(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8')
    except Exception:
        raise AzureException(_ERROR_DECRYPTION_FAILURE)
    def _perform_request(self, request, encoding='utf-8'):
        '''
        Sends the request and return response. Catches HTTPError and hands it
        to error handler
        '''
        try:
            resp = self._filter(request)

            if sys.version_info >= (3,) and isinstance(resp, bytes) and \
                encoding:
                resp = resp.decode(encoding)

        # Parse and wrap HTTP errors in AzureHttpError which inherits from AzureException
        except HTTPError as ex:
            _storage_error_handler(ex)

        # Wrap all other exceptions as AzureExceptions to ease exception handling code
        except Exception as ex:
            if sys.version_info >= (3,):
                # Automatic chaining in Python 3 means we keep the trace
                raise AzureException
            else:
                # There isn't a good solution in 2 for keeping the stack trace 
                # in general, or that will not result in an error in 3
                # However, we can keep the previous error type and message
                # TODO: In the future we will log the trace
                raise AzureException('{}: {}'.format(ex.__class__.__name__, ex.args[0]))

        return resp
Beispiel #5
0
    def _perform_request(self, request, encoding='utf-8'):
        '''
        Sends the request and return response. Catches HTTPError and hands it
        to error handler
        '''
        try:
            response = self._filter(request)
        except Exception as ex:
            if sys.version_info >= (3, ):
                # Automatic chaining in Python 3 means we keep the trace
                raise AzureException
            else:
                # There isn't a good solution in 2 for keeping the stack trace
                # in general, or that will not result in an error in 3
                # However, we can keep the previous error type and message
                # TODO: In the future we will log the trace
                raise AzureException('{}: {}'.format(ex.__class__.__name__,
                                                     ex.args[0]))

        if self.response_callback:
            self.response_callback(response)

        # Parse and wrap HTTP errors in AzureHttpError which inherits from AzureException
        if response.status >= 300:
            # This exception will be caught by the general error handler
            # and raised as an azure http exception
            _storage_error_handler(
                HTTPError(response.status, response.message, response.headers,
                          response.body))

        return response
 def _validate_echoed_client_request_id(request, response):
     '''
     Raises AzureException if the client request ID echoed by the service
     differs from the one sent with the request. Nothing is checked when the
     response does not echo the header.
     '''
     # raise exception if the echoed client request id from the service is not identical to the one we sent
     if _CLIENT_REQUEST_ID_HEADER_NAME in response.headers and \
             request.headers[_CLIENT_REQUEST_ID_HEADER_NAME] != response.headers[_CLIENT_REQUEST_ID_HEADER_NAME]:
         # Look up the service request ID with .get so a missing header does not
         # replace the mismatch error with an unrelated KeyError.
         raise AzureException(
             "Echoed client request ID: {} does not match sent client request ID: {}.  Service request ID: {}"
             .format(response.headers[_CLIENT_REQUEST_ID_HEADER_NAME],
                     request.headers[_CLIENT_REQUEST_ID_HEADER_NAME],
                     response.headers.get('x-ms-request-id')))
Beispiel #7
0
def _decrypt_entity(entity, encrypted_properties_list, content_encryption_key,
                    entityIV, isJavaV1):
    '''
    Decrypts the specified entity using AES256 in CBC mode with 128 bit padding.
    Properties named in encrypted_properties_list are decrypted, PKCS7-unpadded
    and decoded as utf-8 strings. All other properties are copied unchanged,
    and the '_ClientEncryptionMetadata1'/'_ClientEncryptionMetadata2'
    properties are removed from the result.

    :param entity:
        The entity being retrieved and decrypted. Could be a dict or an entity object.
    :param list encrypted_properties_list:
        The names of all the properties that are encrypted.
    :param bytes content_encryption_key:
        The key used internally to encrypt the entity. It is used directly as
        the AES key; no key unwrapping happens in this function.
    :param bytes entityIV:
        The initialization vector used to seed the encryption algorithm. Extracted
        from the entity metadata; each property's IV is derived from it.
    :param bool isJavaV1:
        Forwarded to _generate_property_iv when deriving each property's IV.
    :return: A decrypted deep copy of the entity.
    :rtype: Entity
    :raises AzureException: with _ERROR_DECRYPTION_FAILURE if any step fails.
    '''

    _validate_not_none('entity', entity)

    decrypted_entity = deepcopy(entity)
    try:
        for property_name in entity.keys():
            if property_name in encrypted_properties_list:
                value = entity[property_name]

                # Each property has its own IV, derived by _generate_property_iv
                # from the entity IV, the entity keys and the property name.
                property_iv = _generate_property_iv(entityIV,
                                                    entity['PartitionKey'],
                                                    entity['RowKey'], property_name,
                                                    isJavaV1)
                cipher = _generate_AES_CBC_cipher(content_encryption_key,
                                                  property_iv)

                # Decrypt the property.
                decryptor = cipher.decryptor()
                decrypted_data = (decryptor.update(value.value) +
                                  decryptor.finalize())

                # Unpad the data.
                unpadder = PKCS7(128).unpadder()
                decrypted_data = (unpadder.update(decrypted_data) +
                                  unpadder.finalize())

                decrypted_entity[property_name] = decrypted_data.decode('utf-8')

        # Drop the encryption metadata properties from the result; a missing
        # one raises KeyError, which is reported as a decryption failure.
        decrypted_entity.pop('_ClientEncryptionMetadata1')
        decrypted_entity.pop('_ClientEncryptionMetadata2')
        return decrypted_entity
    except Exception:
        # Exception (not a bare except) keeps KeyboardInterrupt/SystemExit intact.
        raise AzureException(_ERROR_DECRYPTION_FAILURE)
    def _perform_request(self,
                         request,
                         parser=None,
                         parser_args=None,
                         operation_context=None,
                         expected_errors=None):
        '''
        Sends the request and return response. Catches HTTPError and hands it
        to error handler.

        The request is re-sent according to self.retry until it succeeds, the
        retry policy returns None, a decryption failure occurs, or an expected
        HTTP error is received.

        :param request:
            The request to send. If its body exposes read(), the current
            position is recorded (via tell()) so retries can rewind it.
        :param function parser:
            Optional callable applied to the response; its result is returned.
        :param list parser_args:
            Extra positional arguments passed to parser after the response.
        :param operation_context:
            Operation-wide context; a new _OperationContext is used if omitted.
        :param expected_errors:
            Error codes for which an AzureHttpError is re-raised immediately,
            without consulting the retry policy.
        :return: The parser's result, or None if no parser is given.
        :raises AzureException: when the request ultimately fails.
        '''
        operation_context = operation_context or _OperationContext()
        retry_context = RetryContext()
        retry_context.is_emulated = self.is_emulated

        # if request body is a stream, we need to remember its current position in case retries happen
        if hasattr(request.body, 'read'):
            try:
                retry_context.body_position = request.body.tell()
            except (AttributeError, UnsupportedOperation):
                # if body position cannot be obtained, then retries will not work
                pass

        # Apply the appropriate host based on the location mode
        self._apply_host(request, operation_context, retry_context)

        # Apply common settings to the request
        _update_request(request, self._X_MS_VERSION, self._USER_AGENT_STRING)
        client_request_id_prefix = str.format(
            "Client-Request-ID={0}", request.headers['x-ms-client-request-id'])

        while True:
            try:
                try:
                    # Execute the request callback
                    if self.request_callback:
                        self.request_callback(request)

                    # Add date and auth after the callback so date doesn't get too old and
                    # authentication is still correct if signed headers are added in the request
                    # callback. This also ensures retry policies with long back offs
                    # will work as it resets the time sensitive headers.
                    _add_date_header(request)

                    try:
                        # request can be signed individually
                        self.authentication.sign_request(request)
                    except AttributeError:
                        # session can also be signed
                        self.request_session = self.authentication.signed_session(
                            self.request_session)

                    # Set the request context
                    retry_context.request = request

                    # Log the request before it goes out
                    logger.info(
                        "%s Outgoing request: Method=%s, Path=%s, Query=%s, Headers=%s.",
                        client_request_id_prefix, request.method, request.path,
                        request.query,
                        str(request.headers).replace('\n', ''))

                    # Perform the request
                    response = self._httpclient.perform_request(request)

                    # Execute the response callback
                    if self.response_callback:
                        self.response_callback(response)

                    # Set the response context
                    retry_context.response = response

                    # Log the response when it comes back
                    logger.info(
                        "%s Receiving Response: "
                        "%s, HTTP Status Code=%s, Message=%s, Headers=%s.",
                        client_request_id_prefix,
                        self.extract_date_and_request_id(retry_context),
                        response.status, response.message,
                        str(response.headers).replace('\n', ''))

                    # Parse and wrap HTTP errors in AzureHttpError which inherits from AzureException
                    if response.status >= 300:
                        # This exception will be caught by the general error handler
                        # and raised as an azure http exception
                        _http_error_handler(
                            HTTPError(response.status, response.message,
                                      response.headers, response.body))

                    # Parse the response
                    if parser:
                        if parser_args:
                            args = [response]
                            args.extend(parser_args)
                            return parser(*args)
                        else:
                            return parser(response)
                    else:
                        return
                except AzureException as ex:
                    retry_context.exception = ex
                    raise ex
                except Exception as ex:
                    retry_context.exception = ex
                    # Exceptions raised without arguments have no message to forward;
                    # indexing ex.args blindly would raise IndexError here instead.
                    msg = ""
                    if len(ex.args) > 0:
                        msg = ex.args[0]
                    if sys.version_info >= (3, ):
                        # Automatic chaining in Python 3 means we keep the trace
                        raise AzureException(msg)
                    else:
                        # There isn't a good solution in 2 for keeping the stack trace
                        # in general, or that will not result in an error in 3
                        # However, we can keep the previous error type and message
                        # TODO: In the future we will log the trace
                        raise AzureException('{}: {}'.format(
                            ex.__class__.__name__, msg))

            except AzureException as ex:
                # The logger calls below evaluate their arguments even when the
                # record is filtered out, so these names must always be bound.
                exception_str_in_one_line = ''
                status_code = ''
                timestamp_and_request_id = ''
                # only parse the strings used for logging if logging is at least enabled for CRITICAL
                if logger.isEnabledFor(logging.CRITICAL):
                    exception_str_in_one_line = str(ex).replace('\n', '')
                    status_code = retry_context.response.status if retry_context.response is not None else 'Unknown'
                    timestamp_and_request_id = self.extract_date_and_request_id(
                        retry_context)

                # if the http error was expected, we should short-circuit
                if isinstance(
                        ex, AzureHttpError
                ) and expected_errors is not None and ex.error_code in expected_errors:
                    logger.info(
                        "%s Received expected http error: "
                        "%s, HTTP status code=%s, Exception=%s.",
                        client_request_id_prefix, timestamp_and_request_id,
                        status_code, exception_str_in_one_line)
                    raise ex

                logger.info(
                    "%s Operation failed: checking if the operation should be retried. "
                    "Current retry count=%s, %s, HTTP status code=%s, Exception=%s.",
                    client_request_id_prefix, retry_context.count if hasattr(
                        retry_context, 'count') else 0,
                    timestamp_and_request_id, status_code,
                    exception_str_in_one_line)

                # Decryption failures (invalid objects, invalid algorithms, data unencrypted in strict mode, etc)
                # will not be resolved with retries.
                if str(ex) == _ERROR_DECRYPTION_FAILURE:
                    logger.error(
                        "%s Encountered decryption failure: this cannot be retried. "
                        "%s, HTTP status code=%s, Exception=%s.",
                        client_request_id_prefix, timestamp_and_request_id,
                        status_code, exception_str_in_one_line)
                    raise ex

                # Determine whether a retry should be performed and if so, how
                # long to wait before performing retry.
                retry_interval = self.retry(retry_context)
                if retry_interval is not None:
                    # Execute the callback
                    if self.retry_callback:
                        self.retry_callback(retry_context)

                    logger.info(
                        "%s Retry policy is allowing a retry: Retry count=%s, Interval=%s.",
                        client_request_id_prefix, retry_context.count,
                        retry_interval)

                    # Sleep for the desired retry interval
                    sleep(retry_interval)
                else:
                    logger.error(
                        "%s Retry policy did not allow for a retry: "
                        "%s, HTTP status code=%s, Exception=%s.",
                        client_request_id_prefix, timestamp_and_request_id,
                        status_code, exception_str_in_one_line)
                    raise ex
            finally:
                # If this is a location locked operation and the location is not set,
                # this is the first request of that operation. Set the location to
                # be used for subsequent requests in the operation.
                if operation_context.location_lock and not operation_context.host_location:
                    # note: to cover the emulator scenario, the host_location is grabbed
                    # from request.host_locations(which includes the dev account name)
                    # instead of request.host(which at this point no longer includes the dev account name)
                    operation_context.host_location = {
                        retry_context.location_mode:
                        request.host_locations[retry_context.location_mode]
                    }
    def _perform_request(self, request, parser=None, parser_args=None, operation_context=None):
        '''
        Sends the request and return response. Catches HTTPError and hands it
        to error handler.

        The request is re-sent according to self.retry until it succeeds, the
        retry policy returns None, or a decryption failure occurs.

        :param request: The request to send.
        :param function parser:
            Optional callable applied to the response; its result is returned.
        :param list parser_args:
            Extra positional arguments passed to parser after the response.
        :param operation_context:
            Operation-wide context; a new _OperationContext is used if omitted.
        :return: The parser's result, or None if no parser is given.
        :raises AzureException: when the request ultimately fails.
        '''
        operation_context = operation_context or _OperationContext()
        retry_context = RetryContext()

        # Apply the appropriate host based on the location mode
        self._apply_host(request, operation_context, retry_context)

        # Apply common settings to the request
        _update_request(request)

        while True:
            try:
                try:
                    # Execute the request callback
                    if self.request_callback:
                        self.request_callback(request)

                    # Add date and auth after the callback so date doesn't get too old and
                    # authentication is still correct if signed headers are added in the request
                    # callback. This also ensures retry policies with long back offs
                    # will work as it resets the time sensitive headers.
                    _add_date_header(request)
                    self.authentication.sign_request(request)

                    # Set the request context
                    retry_context.request = request

                    # Perform the request
                    response = self._httpclient.perform_request(request)

                    # Execute the response callback
                    if self.response_callback:
                        self.response_callback(response)

                    # Set the response context
                    retry_context.response = response

                    # Parse and wrap HTTP errors in AzureHttpError which inherits from AzureException
                    if response.status >= 300:
                        # This exception will be caught by the general error handler
                        # and raised as an azure http exception
                        _http_error_handler(HTTPError(response.status, response.message, response.headers, response.body))

                    # Parse the response
                    if parser:
                        if parser_args:
                            args = [response]
                            args.extend(parser_args)
                            return parser(*args)
                        else:
                            return parser(response)
                    else:
                        return
                except AzureException as ex:
                    raise ex
                except Exception as ex:
                    # Exceptions raised without arguments have no message to forward;
                    # indexing ex.args blindly would raise IndexError here instead.
                    msg = ""
                    if len(ex.args) > 0:
                        msg = ex.args[0]
                    if sys.version_info >= (3,):
                        # Automatic chaining in Python 3 means we keep the trace
                        raise AzureException(msg)
                    else:
                        # There isn't a good solution in 2 for keeping the stack trace
                        # in general, or that will not result in an error in 3
                        # However, we can keep the previous error type and message
                        # TODO: In the future we will log the trace
                        raise AzureException('{}: {}'.format(ex.__class__.__name__, msg))

            except AzureException as ex:
                # Decryption failures (invalid objects, invalid algorithms, data unencrypted in strict mode, etc)
                # will not be resolved with retries.
                if str(ex) == _ERROR_DECRYPTION_FAILURE:
                    raise ex
                # Determine whether a retry should be performed and if so, how
                # long to wait before performing retry.
                retry_interval = self.retry(retry_context)
                if retry_interval is not None:
                    # Execute the callback
                    if self.retry_callback:
                        self.retry_callback(retry_context)

                    # Sleep for the desired retry interval
                    sleep(retry_interval)
                else:
                    raise ex
            finally:
                # If this is a location locked operation and the location is not set,
                # this is the first request of that operation. Set the location to
                # be used for subsequent requests in the operation.
                if operation_context.location_lock and not operation_context.host_location:
                    operation_context.host_location = {retry_context.location_mode: request.host}
def _convert_json_to_entity(entry_element, property_resolver,
                            encrypted_properties):
    ''' Convert json response to entity.

    The entity format is:
    {
       "Address":"Mountain View",
       "Age":23,
       "AmountDue":200.23,
       "CustomerCode@odata.type":"Edm.Guid",
       "CustomerCode":"c9da6455-213d-42c9-9a79-3e9149a57833",
       "CustomerSince@odata.type":"Edm.DateTime",
       "CustomerSince":"2008-07-10T00:00:00",
       "IsActive":true,
       "NumberOfOrders@odata.type":"Edm.Int64",
       "NumberOfOrders":"255",
       "PartitionKey":"mypartitionkey",
       "RowKey":"myrowkey"
    }

    :param dict entry_element:
        The JSON-decoded entity. Keys starting with 'odata.' are collected as
        OData metadata; keys ending with '@odata.type' give the EDM type of
        the property named by the rest of the key.
    :param function property_resolver:
        Optional callable (partition_key, row_key, name, value, edm_type) that
        returns the EDM type to use for a property.
    :param encrypted_properties:
        Names of encrypted properties, or None. These are always converted as
        EdmType.BINARY.
    :return: The populated Entity, including an 'etag' entry.
    :raises AzureException:
        if the resolver returns an unsupported EDM type, or a type that cannot
        be used to convert the property value.
    '''
    entity = Entity()

    properties = {}
    edmtypes = {}
    odata = {}

    for name, value in entry_element.items():
        if name.startswith('odata.'):
            # strip the 6-character 'odata.' prefix
            odata[name[6:]] = value
        elif name.endswith('@odata.type'):
            # strip the 11-character '@odata.type' suffix
            edmtypes[name[:-11]] = value
        else:
            properties[name] = value

    # Partition key is a known property. Compare against None rather than
    # testing truthiness so an empty-string key is kept on the entity.
    partition_key = properties.pop('PartitionKey', None)
    if partition_key is not None:
        entity['PartitionKey'] = partition_key

    # Row key is a known property (an empty string is kept, as above).
    row_key = properties.pop('RowKey', None)
    if row_key is not None:
        entity['RowKey'] = row_key

    # Timestamp is a known property
    timestamp = properties.pop('Timestamp', None)
    if timestamp:
        entity['Timestamp'] = _from_entity_datetime(timestamp)

    for name, value in properties.items():
        mtype = edmtypes.get(name)

        # use the property resolver if present
        if property_resolver:
            # Clients are not expected to resolve these internal fields.
            # This check avoids unexpected behavior from the user-defined
            # property resolver.
            if not (name == '_ClientEncryptionMetadata1'
                    or name == '_ClientEncryptionMetadata2'):
                mtype = property_resolver(partition_key, row_key, name, value,
                                          mtype)

                # throw if the type returned is not a valid edm type
                if mtype and mtype not in _EDM_TYPES:
                    raise AzureException(
                        _ERROR_TYPE_NOT_SUPPORTED.format(mtype))

        # If the property was encrypted, supersede the results of the resolver and set as binary
        if encrypted_properties is not None and name in encrypted_properties:
            mtype = EdmType.BINARY

        # Add type for Int32. The exact type check keeps bool (an int
        # subclass) from being tagged as Int32.
        if type(value) is int:
            mtype = EdmType.INT32

        # no type info, property should parse automatically
        if not mtype:
            entity[name] = value
        else:  # need an object to hold the property
            conv = _ENTITY_TO_PYTHON_CONVERSIONS.get(mtype)
            if conv is not None:
                try:
                    converted = conv(value)
                except Exception:
                    # throw if the type returned by the property resolver
                    # cannot be used in the conversion
                    if property_resolver:
                        raise AzureException(
                            _ERROR_INVALID_PROPERTY_RESOLVER.format(
                                name, value, mtype))
                    # bare raise keeps the original traceback
                    raise
            else:
                converted = EntityProperty(mtype, value)
            entity[name] = converted

    # extract etag from entry; when a Timestamp is present the etag is
    # synthesized from it instead, overriding odata.etag.
    etag = odata.get('etag')
    if timestamp:
        etag = 'W/"datetime\'' + url_quote(timestamp) + '\'"'
    entity['etag'] = etag

    return entity
Beispiel #11
0
def _validate_access_policies(identifiers):
    if identifiers and len(identifiers) > 5:
        raise AzureException(_ERROR_TOO_MANY_ACCESS_POLICIES)
Beispiel #12
0
def _validate_content_match(server_md5, computed_md5):
    if server_md5 != computed_md5:
        raise AzureException(
            _ERROR_MD5_MISMATCH.format(server_md5, computed_md5))
def throw_azure_exception(scope):
    """Unconditionally raise an ``AzureException`` with no arguments.

    The ``scope`` argument is accepted but ignored.
    """
    error = AzureException()
    raise error
Beispiel #14
0
class AzureProviderTestCase(TestCase):
    """Test cases for AzureProvider."""

    @staticmethod
    def _fake_credentials():
        """Build a credentials dict populated with fake values."""
        return {
            'subscription_id': FAKE.uuid4(),
            'tenant_id': FAKE.uuid4(),
            'client_id': FAKE.uuid4(),
            'client_secret': FAKE.word()
        }

    @staticmethod
    def _fake_source_name():
        """Build a billing-source dict populated with fake values."""
        return {
            'resource_group': FAKE.word(),
            'storage_account': FAKE.word()
        }

    def test_name(self):
        """Test name property."""
        provider = AzureProvider()
        self.assertEqual(provider.name(), 'Azure')

    @patch('providers.azure.provider.AzureClientFactory')
    def test_cost_usage_source_is_reachable_valid(self, _):
        """Test that cost_usage_source_is_reachable succeeds."""
        credentials = self._fake_credentials()
        source_name = self._fake_source_name()
        provider = AzureProvider()
        reachable = provider.cost_usage_source_is_reachable(credentials, source_name)
        self.assertTrue(reachable)

    @patch('providers.azure.provider.AzureClientFactory',
           side_effect=AzureException('test exception'))
    def test_cost_usage_source_is_reachable_exception(self, _):
        """Test that ValidationError is raised when AzureException is raised."""
        credentials = self._fake_credentials()
        source_name = self._fake_source_name()
        with self.assertRaises(ValidationError):
            AzureProvider().cost_usage_source_is_reachable(
                credentials, source_name)

    def test_cost_usage_source_is_reachable_badargs(self):
        """Test that a ValidationError is raised when no arguments are provided."""
        def expect_validation_error(credentials, source_name):
            with self.assertRaises(ValidationError):
                AzureProvider().cost_usage_source_is_reachable(
                    credentials, source_name)

        expect_validation_error(None, None)
        expect_validation_error(FAKE.word(), FAKE.word())
        expect_validation_error({}, {})

    def test_infra_type_implementation(self):
        """Test that infra type returns None."""
        provider = AzureProvider()
        self.assertIsNone(
            provider.infra_type_implementation(FAKE.word(), FAKE.word()))

    def test_infra_key_list_implementation(self):
        """Test that infra key list returns an empty list."""
        provider = AzureProvider()
        self.assertEqual(
            provider.infra_key_list_implementation(FAKE.uuid4(), FAKE.word()), [])
Beispiel #15
0
class AzureProviderTestCase(TestCase):
    """Test cases for AzureProvider."""

    @staticmethod
    def _fake_credentials():
        """Build a credentials dict populated with fake values."""
        return {
            "subscription_id": FAKE.uuid4(),
            "tenant_id": FAKE.uuid4(),
            "client_id": FAKE.uuid4(),
            "client_secret": FAKE.word(),
        }

    @staticmethod
    def _fake_source_name():
        """Build a billing-source dict populated with fake values."""
        return {
            "resource_group": FAKE.word(),
            "storage_account": FAKE.word()
        }

    def test_name(self):
        """Test name property."""
        provider = AzureProvider()
        self.assertEqual(provider.name(), "Azure")

    @patch("providers.azure.provider.AzureClientFactory")
    def test_cost_usage_source_is_reachable_valid(self, _):
        """Test that cost_usage_source_is_reachable succeeds."""
        credentials = self._fake_credentials()
        source_name = self._fake_source_name()
        with patch("providers.azure.provider.AzureService") as mock_service:
            exports = mock_service.return_value.describe_cost_management_exports
            exports.return_value = ["report1"]
            provider = AzureProvider()
            reachable = provider.cost_usage_source_is_reachable(credentials, source_name)
            self.assertTrue(reachable)

    @patch("providers.azure.provider.AzureClientFactory",
           side_effect=AzureException("test exception"))
    def test_cost_usage_source_is_reachable_exception(self, _):
        """Test that ValidationError is raised when AzureException is raised."""
        credentials = self._fake_credentials()
        source_name = self._fake_source_name()
        with self.assertRaises(ValidationError):
            AzureProvider().cost_usage_source_is_reachable(
                credentials, source_name)

    def test_cost_usage_source_is_reachable_badargs(self):
        """Test that a ValidationError is raised when no arguments are provided."""
        def expect_validation_error(credentials, source_name):
            with self.assertRaises(ValidationError):
                AzureProvider().cost_usage_source_is_reachable(
                    credentials, source_name)

        expect_validation_error(None, None)
        expect_validation_error(FAKE.word(), FAKE.word())
        expect_validation_error({}, {})

    def test_infra_type_implementation(self):
        """Test that infra type returns None."""
        provider = AzureProvider()
        self.assertIsNone(
            provider.infra_type_implementation(FAKE.word(), FAKE.word()))

    def test_infra_key_list_implementation(self):
        """Test that infra key list returns an empty list."""
        provider = AzureProvider()
        self.assertEqual(
            provider.infra_key_list_implementation(FAKE.uuid4(), FAKE.word()), [])

    @patch("providers.azure.provider.AzureClientFactory")
    def test_cost_usage_source_reachable_without_cost_export(self, _):
        """Test that cost_usage_source_is_reachable raises an exception when no cost reports exist."""
        credentials = self._fake_credentials()
        source_name = self._fake_source_name()

        with patch("providers.azure.provider.AzureService") as mock_service:
            exports = mock_service.return_value.describe_cost_management_exports
            exports.return_value = []
            provider = AzureProvider()
            with self.assertRaises(ValidationError):
                provider.cost_usage_source_is_reachable(
                    credentials, source_name)