예제 #1
0
def fetch_resources(query, resource_type, secrets, configuration):
    """Run a Resource Graph query for ``resource_type`` and return the rows.

    The final query text is produced by ``__create_resource_graph_query`` from
    the caller's ``query``, ``resource_type`` and ``configuration``.  The rows in
    the response are converted with ``to_dicts`` using the client's API version.
    """
    with auth(secrets) as cred:
        graph_query = __create_resource_graph_query(
            query, resource_type, configuration)
        graph_client = ResourceGraphClient(credentials=cred)
        response = graph_client.resources(graph_query)
        return to_dicts(response.data, graph_client.api_version)
예제 #2
0
    def __init__(self,
                 subscription_id=DEFAULT_SCOPE,
                 schema=SCHEMA_NAME,
                 models=MODELS,
                 update_field=UPD_FIELD_NAME):
        """Create a Resource Graph client bound to ``subscription_id``.

        Args:
            subscription_id: Subscription the client is scoped to.
            schema: Stored unchanged on ``self.schema``.
            models: Stored unchanged on ``self.models``.
            update_field: Stored unchanged on ``self.update_field``.
        """
        self.subscription_id = subscription_id
        # Authenticate with the shared AZURE_CLIENT credential.
        shared_credential = AZURE_CLIENT.credential
        self.client = ResourceGraphClient(
            credential=shared_credential,
            subscription_id=subscription_id,
        )
        self.schema, self.models, self.update_field = schema, models, update_field
    def get_client(self, client_type):
        """Return a resource management client of the requested type.

        Args:
            client_type: Name of the client to build; either
                ``"ComputeManagementClient"`` or ``"ResourceGraphClient"``.

        Returns:
            A new client instance built from ``self.credentials``.

        Raises:
            ServicePrincipalError: If ``self.credentials`` is falsy.
            NotImplementedError: If ``client_type`` is not a supported name.
        """
        # Check that credentials exist before returning an Azure client object.
        if not self.credentials:
            raise ServicePrincipalError(
                "Missing or bad credentials for tenant {}".format(
                    self.tenant_name))

        # There is probably a better way than a chain of comparisons
        # (e.g. a name -> factory mapping) if more client types are added.
        if client_type == "ComputeManagementClient":
            return ComputeManagementClient(self.credentials,
                                           self.subscription_id)
        if client_type == "ResourceGraphClient":
            return ResourceGraphClient(self.credentials, base_url=None)

        raise NotImplementedError(
            "No such client type {} supported".format(client_type))
예제 #4
0
def init_resource_graph_client(secrets: Secrets) -> ResourceGraphClient:
    """Build a ResourceGraphClient for the cloud named by ``secrets["azure_cloud"]``.

    The client's ``base_url`` is the resource-manager endpoint of the cloud
    environment returned by ``__get_cloud_env_by_name``.
    """
    with auth(secrets) as cred:
        cloud_env = __get_cloud_env_by_name(secrets.get("azure_cloud"))
        return ResourceGraphClient(
            credentials=cred,
            base_url=cloud_env.endpoints.resource_manager)
예제 #5
0
def init_resource_graph_client(secrets: Secrets) -> ResourceGraphClient:
    """Build a ResourceGraphClient from the authentication object for ``secrets``.

    The resource-manager endpoint is taken from the authentication object's
    ``cloud_environment``.
    """
    with auth(secrets) as authentication:
        endpoints = authentication.cloud_environment.endpoints
        return ResourceGraphClient(credentials=authentication,
                                   base_url=endpoints.resource_manager)
예제 #6
0
def init_resource_graph_client(
        experiment_secrets: Secrets) -> ResourceGraphClient:
    """Build a ResourceGraphClient from raw experiment secrets.

    The secrets are first normalised with ``load_secrets``.  The resource-manager
    endpoint comes from the ``'cloud'`` entry of the loaded secrets.
    """
    secrets = load_secrets(experiment_secrets)
    with auth(secrets) as authentication:
        cloud = secrets.get('cloud')
        return ResourceGraphClient(
            credentials=authentication,
            base_url=cloud.endpoints.resource_manager)
    def __init__(self,
                 tracer: logging.Logger,
                 subscriptionId: str,
                 authCredential: ManagedIdentityCredential):
        """Store the logger, subscription and credential, and build the ARG client.

        Args:
            tracer (logging.Logger): Logger object.
            subscriptionId (str): Subscription Id in the context of which the queries would be run.
            authCredential (ManagedIdentityCredential): Credential of the managed identity which would be used for authorization.
        """
        self.tracer, self.subscriptionId, self.authCredential = (
            tracer, subscriptionId, authCredential)

        # Azure Resource Graph client scoped to the given subscription and
        # authenticated with the stored credential.
        self.argClient = ResourceGraphClient(
            credential=self.authCredential,
            subscription_id=self.subscriptionId)
예제 #8
0
def main():
    """Run a few sample Azure Resource Graph queries and print each response.

    The target subscription is read from the ``SUBSCRIPTION_ID`` environment
    variable, and the client authenticates with ``DefaultAzureCredential``.

    Raises:
        SystemExit: If ``SUBSCRIPTION_ID`` is not set (or is empty).
    """
    import os

    # If "SUBSCRIPTION_ID" is not set in the environment variable, you need to set it manually: export SUBSCRIPTION_ID="{SUBSCRIPTION_ID}"
    subscription_id = os.environ.get("SUBSCRIPTION_ID")
    if not subscription_id:
        # Fail fast: otherwise `None` would be sent as the subscription in
        # every query below.
        raise SystemExit(
            'SUBSCRIPTION_ID is not set; run: export SUBSCRIPTION_ID="{SUBSCRIPTION_ID}"')

    # Create client
    # For other authentication approaches, please see: https://pypi.org/project/azure-identity/
    resourcegraph_client = ResourceGraphClient(
        credential=DefaultAzureCredential(), subscription_id=subscription_id)

    def run_and_print(title, query_request):
        """Send ``query_request`` to ARG and print the response under ``title``."""
        query_response = resourcegraph_client.resources(query_request)
        print("{}:\n{}".format(title, query_response))

    # Basic query up to 2 pieces of data
    run_and_print(
        "Basic query up to 2 pieces of data",
        QueryRequest(query='project id, tags, properties | limit 2',
                     subscriptions=[subscription_id]))

    # Basic query up to 2 pieces of object array
    run_and_print(
        "Basic query up to 2 pieces of object array",
        QueryRequest(
            query='project id, tags, properties | limit 2',
            subscriptions=[subscription_id],
            options=QueryRequestOptions(
                result_format=ResultFormat.object_array)))

    # Query with options
    run_and_print(
        "Query with options",
        QueryRequest(query='project id',
                     subscriptions=[subscription_id],
                     options=QueryRequestOptions(top=4, skip=8)))

    # Query with facet expressions; the second facet deliberately names a
    # column that does not exist.
    facet_expressions = ('location', 'nonExistingColumn')
    run_and_print(
        "Query with facet expressions",
        QueryRequest(
            query='project id, location | limit 10',
            subscriptions=[subscription_id],
            facets=[
                FacetRequest(expression=expression,
                             options=FacetRequestOptions(sort_order='desc',
                                                         top=1))
                for expression in facet_expressions
            ]))
class AzureResourceGraph(metaclass=Singleton):
    """Singleton class that provides access to Azure Resource Graph using the SDK.

    ``getResources`` is the entry point. It pages through results using skip
    tokens and retries calls whose response reports throttling (HTTP 429) or
    whose status code could not be extracted.
    """

    logTag = "[AIOps][ARG]"

    def __init__(self,
                 tracer: logging.Logger,
                 subscriptionId: str,
                 authCredential: ManagedIdentityCredential):
        """Constructor

        Args:
            tracer (logging.Logger): Logger object.
            subscriptionId (str): Subscription Id in the context of which the queries would be run.
            authCredential (ManagedIdentityCredential): Credential of the managed identity which would be used for authorization.
        """
        self.tracer = tracer
        self.subscriptionId = subscriptionId
        self.authCredential = authCredential

        # Create Azure Resource Graph client.
        self.argClient = ResourceGraphClient(
            credential=self.authCredential,
            subscription_id=self.subscriptionId
        )

    def __customResponse(self, pipelineResponse: PipelineResponse, deserializedQueryResponse: QueryResponse, *args) -> QueryResponse:
        """Extract the throttling headers from the ARG HTTP response and attach them to the deserialized response.

        Passed as the ``cls`` callback to ``ResourceGraphClient.resources``.

        Args:
            pipelineResponse (PipelineResponse): HTTP response from ARG.
            deserializedQueryResponse (QueryResponse): Deserialized response from ARG SDK.
            *args: Any extra positional arguments the SDK passes to ``cls`` callbacks; ignored.

        Returns:
            QueryResponse: Deserialized response with the following additional properties set:
            quotaRemaining, quotaResetsAfter, statusCode. Values that could not be extracted
            stay None, except quotaResetsAfter, which falls back to DEFAULT_QUOTA_RESETS_AFTER.
        """
        self.tracer.info(
            "%s Extracting throttling headers from ARG response.", self.logTag)
        quotaRemaining = None
        quotaResetsAfter = None
        statusCode = None

        try:
            if pipelineResponse is None:
                errorMessage = "%s Pipeline response received from ARG SDK is None." % self.logTag
                raise Exception(errorMessage)

            httpResponse = pipelineResponse.http_response
            # Read the public azure-core `headers` mapping instead of reaching into
            # the transport's private `internal_response.headers._store`.
            headers = httpResponse.headers
            self.tracer.info(
                "%s Headers from ARG SDK response = %s", self.logTag, headers)
            statusCode = httpResponse.status_code
            self.tracer.info(
                "%s Status code from ARG SDK pipeline response = %s", self.logTag, statusCode)
            quotaRemaining = int(headers[QUOTA_REMAINING_HEADER])
            quotaResetsAfter = self.__getSeconds(
                headers[QUOTA_RESETS_AFTER_HEADER])
        except Exception as e:
            # Best effort: missing headers must not fail the query itself.
            self.tracer.error(
                "%s Could not extract the throttling headers from ARG response. (%s)", self.logTag, e, exc_info=True)

        # If the header couldn't be extracted, set the default value.
        if quotaResetsAfter is None:
            quotaResetsAfter = DEFAULT_QUOTA_RESETS_AFTER

        # Adding additional properties to the response which can be used to handle throttling.
        deserializedQueryResponse.quotaRemaining = quotaRemaining
        deserializedQueryResponse.quotaResetsAfter = quotaResetsAfter
        deserializedQueryResponse.statusCode = statusCode

        return deserializedQueryResponse

    # Method that wraps the ARG Resources method and handles pagination as well as throttling.
    def getResources(self, subscriptionIds: List[str], query: str) -> List[Dict[str, str]]:
        """Wrapper around the ARG resources method. Handles pagination as well as throttling.

        Args:
            subscriptionIds (List[str]): List of subscriptions within which the resources should be queried.
            query (str): The query to be executed.

        Returns:
            List[Dict[str, str]]: List of resources, each resource being a dictionary of key-value pairs.

        Raises:
            ValueError, TypeError: If the inputs are invalid (see __validateInputs).
            Exception: Any error from the ARG calls is logged and re-raised.
        """
        self.tracer.info("%s Getting resources using Azure Resource Graph for subscriptionIds=%s and query=%s.",
                         self.logTag, subscriptionIds, query)
        results = []

        # Guard clauses.
        self.__validateInputs(subscriptionIds, query)

        # Get the resources
        totalNumberOfResources = None
        try:
            # First call to ARG.
            self.tracer.info(
                "%s First request to ARG for resources.", self.logTag)
            argQueryResponse = self.__triggerArgResourcesMethod(
                query, subscriptionIds)
            totalNumberOfResources = argQueryResponse.total_records
            self.tracer.info(
                "%s Number of resources received = %s. numberOfResultsCompiledSoFar=%s; totalNumberOfResourcesExpected=%s; query=%s",
                self.logTag, len(argQueryResponse.data), len(results), totalNumberOfResources, query)
            results.extend(argQueryResponse.data)

            # If there are more than one page of results, use skip token to retrieve the subsequent pages.
            while argQueryResponse.skip_token is not None:
                self.tracer.info(
                    "%s Requesting for the next page of results from ARG.", self.logTag)
                argQueryResponse = self.__triggerArgResourcesMethod(
                    query, subscriptionIds, argQueryResponse.skip_token)
                self.tracer.info(
                    "%s Number of resources received = %s. numberOfResultsCompiledSoFar=%s; totalNumberOfResourcesExpected=%s; query=%s",
                    self.logTag, len(argQueryResponse.data), len(results), totalNumberOfResources, query)
                results.extend(argQueryResponse.data)

            self.tracer.info(
                "%s Completed ARG call(s). totalNumberOfResultsCompiled= %s; totalNumberOfResourcesExpected=%s; query=%s",
                self.logTag, len(results), totalNumberOfResources, query)
            return results
        except Exception as e:
            self.tracer.error(
                "%s Could not get the resources using ARG. subscription=%s; query=%s; numberOfResultsCompiledSoFar=%s; totalNumberOfResourcesExpected(None if the first call itself failed)=%s.(%s)", self.logTag, subscriptionIds, query, len(results), totalNumberOfResources, e, exc_info=True)
            raise

    def __validateInputs(self, subscriptionIds, query):
        """Validate inputs passed to getResources.

        Args:
            subscriptionIds (List[str]): List of subscription Ids passed to getResources.
            query (str): Query passed to getResources.

        Raises:
            ValueError: If subscriptionIds is None or empty. If query is empty.
            TypeError: If subscriptionIds is not a list.
        """
        if subscriptionIds is None:
            raise ValueError(
                '%s subscriptionIds argument cannot be None.' % self.logTag)
        if not isinstance(subscriptionIds, list):
            raise TypeError(
                '%s subscriptionIds argument should be of type list.' % self.logTag)
        if not subscriptionIds:
            raise ValueError(
                '%s subscriptionIds argument should contain atleast one id.' % self.logTag)
        if not query:
            raise ValueError(
                '%s query argument cannot be empty.' % self.logTag)

    def __triggerArgResourcesMethod(self, query: str, subscriptionIds: List[str], skipToken: str = None) -> QueryResponse:
        """Trigger the ARG Resources method, retrying while the response reports throttling.

        Args:
            query (str): Query to be run.
            subscriptionIds (List[str]): List of subscriptions within which the resources should be queried.
            skipToken (str, optional): Token from last run in case of pagination. Defaults to None.

        Raises:
            Exception: Maximum number of retries is exceeded, or the quota reset time
                exceeds MAX_QUOTA_RESETS_AFTER. Any exception raised by the SDK call
                itself is logged and re-raised.

        Returns:
            QueryResponse: Response from ARG.
        """
        self.tracer.info(
            "%s Entered __triggerArgResourcesMethod", self.logTag)

        # Build the query options; a skip token is only sent when continuing pagination.
        self.tracer.info("%s Building query request options.", self.logTag)
        if skipToken is None:
            argQueryOptions = QueryRequestOptions(
                result_format=RESULT_FORMAT)
        else:
            argQueryOptions = QueryRequestOptions(
                result_format=RESULT_FORMAT, skip_token=skipToken)

        # Build the query request.
        self.tracer.info("%s Building ARG query request.", self.logTag)
        argQuery = QueryRequest(
            query=query,
            subscriptions=subscriptionIds,
            options=argQueryOptions
        )

        # Call the ARG method in a loop to handle throttling.
        retries = 0
        while retries <= MAX_RETRIES:
            self.tracer.info(
                "%s Invoking resources method of ARG SDK.", self.logTag)
            # Track latency of the SDK call.
            latencyStartTime = time.time()
            try:
                argQueryResponse = self.argClient.resources(
                    argQuery, cls=self.__customResponse)
            except Exception as e:
                latency = TimeUtils.getElapsedMilliseconds(latencyStartTime)
                self.tracer.error("%s ARG call failed. subscription=%s; query=%s; latency=%s ms. (%s)",
                                  self.logTag, subscriptionIds, query, latency, e, exc_info=True)
                # There is no response object to inspect for throttling information,
                # so surface the SDK error instead of failing later on a None response.
                raise
            latency = TimeUtils.getElapsedMilliseconds(latencyStartTime)

            if self.__shouldRetry(argQueryResponse):
                self.tracer.info(
                    "%s Throttling limit is hit or failed to extract the status code header. ARG call took %s ms. subscription=%s; query=%s;",
                    self.logTag, latency, subscriptionIds, query)
                retries += 1
                if retries > MAX_RETRIES:
                    # No attempts left; give up without waiting for the quota reset.
                    break
                self.__waitForQuotaReset(argQueryResponse.quotaResetsAfter)
                continue

            self.tracer.info("%s Throttling limit not hit. subscription=%s; query=%s;",
                             self.logTag, subscriptionIds, query)
            self.tracer.info(
                "%s ARG response receieved. subscription=%s; query=%s; totalNumberOfResources=%s; numberOfResourcesInCurrentPage=%s; latency=%s ms.",
                self.logTag, subscriptionIds, query, argQueryResponse.total_records, argQueryResponse.count, latency)
            return argQueryResponse

        # Max number of retries exceeded.
        errorMessage = "%s Maximum number of retries exceeded (MaxRetriesConfig=%s; CurrentRetries=%s). Aborting the ARG call. subscription=%s; query=%s" % (
            self.logTag, MAX_RETRIES, retries, subscriptionIds, query)
        self.tracer.error(errorMessage)
        raise Exception(errorMessage)

    # Check if the throttling limit has been hit for ARG calls
    def __shouldRetry(self, argQueryResponse: QueryResponse) -> bool:
        """Check if the ARG call should be retried.

        Args:
            argQueryResponse (QueryResponse): Response from ARG, as decorated by __customResponse.

        Returns:
            bool: True if the throttling limit has been hit (status code 429) or if the status code is None.
        """
        statusCode = argQueryResponse.statusCode
        self.tracer.info("%s Status code for the ARG call = %s",
                         self.logTag, str(statusCode))
        # statusCode will be None when HTTP response header extraction failed
        return statusCode is None or statusCode == 429

    def __waitForQuotaReset(self, quotaResetsAfter: int) -> None:
        """Sleep for 1-3x quotaResetsAfter seconds (random multiplier) to avoid bursting.

        Args:
            quotaResetsAfter (int): Seconds after which the throttling quota resets.

        Raises:
            Exception: If quotaResetsAfter is greater than MAX_QUOTA_RESETS_AFTER.
        """
        # Do not wait if quotaResetsAfter is greater than the pre-configured time. Exit early.
        if quotaResetsAfter > MAX_QUOTA_RESETS_AFTER:
            errorMessage = "%s Quota will reset after %s seconds which is higher than the max wait time. Aborting the ARG call." % (
                self.logTag, quotaResetsAfter)
            self.tracer.error(errorMessage)
            raise Exception(errorMessage)
        delay = random.randint(1, 3) * quotaResetsAfter
        self.tracer.info("%s waiting for %s seconds", self.logTag, delay)
        time.sleep(delay)

    def __getSeconds(self, strTime: str) -> int:
        """Convert quota-resets-after header value from string to number of seconds.

        Args:
            strTime (str): Time in hh:mm:ss format.

        Returns:
            int: Number of seconds.
        """
        h, m, s = strTime.split(':')
        return int(h) * 3600 + int(m) * 60 + int(s)
예제 #10
0
class AzureRGConnector(GenericExtractor):
    """Extractor that queries Azure Resource Graph within a single subscription."""

    def __init__(self,
                 subscription_id=DEFAULT_SCOPE,
                 schema=SCHEMA_NAME,
                 models=MODELS,
                 update_field=UPD_FIELD_NAME):
        """Create the Resource Graph client and store the extractor settings.

        Args:
            subscription_id: Subscription all queries are scoped to.
            schema: Stored unchanged on ``self.schema``.
            models: Stored unchanged on ``self.models``.
            update_field: Stored unchanged on ``self.update_field``.
        """
        self.subscription_id = subscription_id
        self.client = ResourceGraphClient(credential=AZURE_CLIENT.credential,
                                          subscription_id=subscription_id)
        self.schema = schema
        self.models = models
        self.update_field = update_field

    def get_count(self, model, search_domains=None):
        """Return the number of resources matching ``model``.

        Args:
            model: Model description dict (see ``forge_query``).
            search_domains: Currently unused; kept for interface compatibility.

        Returns:
            The ``Count`` value from the first row of the count query's result.
        """
        queryStr = self.forge_query(model, count=True)

        query = QueryRequest(query=queryStr,
                             subscriptions=[self.subscription_id])
        query_response = self.client.resources(query)

        # A query ending in `| count` yields a single row with a `Count` column.
        return query_response.data[0]['Count']

    def read_query(self, model, search_domains=None, start_row=0):
        """Run the projection query for ``model`` and return the result rows.

        Args:
            model: Model description dict (see ``forge_query``).
            search_domains: Currently unused; kept for interface compatibility.
            start_row: Currently unused; the query always returns the first
                ``PAGE_SIZE`` rows (``forge_query`` default).

        Returns:
            ``query_response.data`` from the Resource Graph response.
        """
        queryStr = self.forge_query(model)

        query = QueryRequest(query=queryStr,
                             subscriptions=[self.subscription_id])
        query_response = self.client.resources(query)
        # Log instead of printing to stdout from library code.
        logger.debug("Basic query :\n%s", query_response)

        return query_response.data

    def forge_query(self, model, page_size=PAGE_SIZE, count=False):
        """Build the KQL query string for ``model``.

        Args:
            model: Dict with ``'base_name'`` (resource type matched with ``=~``),
                ``'fields'`` (dict whose keys are the projected columns) and an
                optional ``'class'`` (table name, default ``'Resources'``).
            page_size: Row limit appended when ``count`` is False.
            count: When True, append ``| count`` instead of a ``limit``.

        Returns:
            The query string.
        """
        class_scope = model.get('class', 'Resources')
        base_name = model['base_name']
        fieldnames = ",".join(model['fields'].keys())

        queryStr = "{} | where type =~ '{}' | project {}".format(
            class_scope, base_name, fieldnames)
        if count:
            queryStr += " | count"
        else:
            queryStr += " | limit {}".format(page_size)

        logger.debug("QUERY STR : %s", queryStr)

        return queryStr

    def forge_item(self, input_dict, model):
        '''TODO function to forge outputs from Azure Resource Graph API'''

        return input_dict
예제 #11
0
 def resource_graph(self) -> ResourceGraphClient:
     """Create a new ResourceGraphClient authenticated with ``self.credential``."""
     graph_client = ResourceGraphClient(credential=self.credential)
     return graph_client