class RestAPIClient(object):
    """Generic HTTP client configured from a credential and an endpoint dict.

    All ``requests`` keyword arguments (headers, query parameters,
    authentication, SSL verification, timeout, body) are computed once in
    ``__init__`` and stored in ``self.requests_kwargs``; requests are then
    issued with optional throttling, loop detection and pagination.
    """

    def __init__(self, credential, endpoint, custom_key_values=None):
        """Build the request configuration.

        :param credential: dict of credential settings (``login_type``,
            tokens, ``user_defined_keys``...). Its entries are also exposed
            as template variables.
        :param endpoint: dict of endpoint settings (``endpoint_url``,
            ``http_method``, headers, query string, pagination, throttling,
            body...). Its entries are also exposed as template variables.
        :param custom_key_values: optional dict of extra template variables.
        """
        logger.info(
            "Initialising RestAPIClient, credential={}, endpoint={}".format(
                logger.filter_secrets(credential), endpoint))
        if custom_key_values is None:
            # Avoid a shared mutable default argument.
            custom_key_values = {}

        #  presets_variables contains all variables available in templates using the {{variable_name}} notation
        self.presets_variables = {}
        self.presets_variables.update(endpoint)
        self.presets_variables.update(credential)
        self.presets_variables.update(custom_key_values)

        #  requests_kwargs contains **kwargs used for requests
        self.requests_kwargs = {}

        self.endpoint_query_string = endpoint.get("endpoint_query_string", [])
        user_defined_keys = credential.get("user_defined_keys", [])
        self.user_defined_keys = self.get_params(user_defined_keys,
                                                 self.presets_variables)
        # User defined keys become template variables for everything below.
        self.presets_variables.update(self.user_defined_keys)

        endpoint_url = endpoint.get("endpoint_url", "")
        self.endpoint_url = format_template(endpoint_url,
                                            **self.presets_variables)
        self.http_method = endpoint.get("http_method", "GET")

        endpoint_headers = endpoint.get("endpoint_headers", "")
        self.endpoint_headers = self.get_params(endpoint_headers,
                                                self.presets_variables)

        self.params = self.get_params(self.endpoint_query_string,
                                      self.presets_variables)

        self.extraction_key = endpoint.get("extraction_key", None)

        # Must run after endpoint_headers and params exist: it may add
        # authentication entries to either of them.
        self.set_login(credential)

        self.requests_kwargs.update({"headers": self.endpoint_headers})
        self.ignore_ssl_check = endpoint.get("ignore_ssl_check", False)
        self.requests_kwargs.update({"verify": not self.ignore_ssl_check})
        self.timeout = endpoint.get("timeout", -1)
        # Tolerate an explicit None (a bare `None > 0` fails on Python 3).
        if self.timeout is not None and self.timeout > 0:
            self.requests_kwargs.update({"timeout": self.timeout})

        self.requests_kwargs.update({"params": self.params})
        self.pagination = Pagination()
        next_page_url_key = (endpoint.get("next_page_url_key") or "").split('.')
        top_key = endpoint.get("top_key")
        skip_key = endpoint.get("skip_key")
        pagination_type = endpoint.get("pagination_type", "na")
        self.pagination.configure_paging(skip_key=skip_key,
                                         limit_key=top_key,
                                         next_page_key=next_page_url_key,
                                         url=self.endpoint_url,
                                         pagination_type=pagination_type)
        self.last_interaction = None
        self.requests_per_minute = endpoint.get("requests_per_minute", -1)
        self.time_between_requests = self._compute_time_between_requests(
            self.requests_per_minute)
        self.time_last_request = None
        self.loop_detector = LoopDetector()
        body_format = endpoint.get("body_format", None)
        if body_format == DKUConstants.RAW_BODY_FORMAT:
            text_body = endpoint.get("text_body", "")
            self.requests_kwargs.update({"data": text_body})
        elif body_format in [DKUConstants.FORM_DATA_BODY_FORMAT]:
            key_value_body = endpoint.get("key_value_body", {})
            self.requests_kwargs.update(
                {"json": get_dku_key_values(key_value_body)})

    @staticmethod
    def _compute_time_between_requests(requests_per_minute):
        """Return the minimum delay (seconds) between requests, or None.

        None means "no throttling" (rate missing, zero or negative).
        """
        if requests_per_minute is None or requests_per_minute <= 0:
            return None
        # 60.0 forces true division: under Python 2, 60 / 100 == 0 would
        # silently disable throttling.
        return 60.0 / requests_per_minute

    def set_login(self, credential):
        """Add authentication settings according to ``login_type``.

        ``basic_login`` sets ``requests_kwargs["auth"]``. ``bearer_token``
        adds an ``Authorization`` header. ``api_key`` adds the key to the
        headers or, for any other destination, to the query parameters.
        Any other login type adds nothing.
        """
        login_type = credential.get("login_type", "no_auth")
        if login_type == "basic_login":
            self.username = credential.get("username", "")
            self.password = credential.get("password", "")
            self.auth = (self.username, self.password)
            self.requests_kwargs.update({"auth": self.auth})
        elif login_type == "bearer_token":
            token = credential.get("token", "")
            bearer_template = credential.get("bearer_template",
                                             "Bearer {{token}}")
            bearer_template = bearer_template.replace("{{token}}", token)
            self.endpoint_headers.update({"Authorization": bearer_template})
        elif login_type == "api_key":
            self.api_key_name = credential.get("api_key_name", "")
            self.api_key_value = credential.get("api_key_value", "")
            self.api_key_destination = credential.get("api_key_destination",
                                                      "header")
            if self.api_key_destination == "header":
                self.endpoint_headers.update(
                    {self.api_key_name: self.api_key_value})
            else:
                self.params.update({self.api_key_name: self.api_key_value})

    def get(self, url, can_raise_exeption=True, **kwargs):
        """Shortcut for ``request("GET", url, ...)``."""
        json_response = self.request("GET",
                                     url,
                                     can_raise_exeption=can_raise_exeption,
                                     **kwargs)
        return json_response

    def request(self, method, url, can_raise_exeption=True, **kwargs):
        """Send one HTTP request and return its decoded JSON body.

        ``kwargs`` are templated with ``presets_variables`` before being
        passed to ``requests.request``.

        :raises RestAPIClientError: if a pagination loop is detected. Also
            raised, when ``can_raise_exeption`` is true, on a transport
            error, an HTTP status >= 400 or a non-JSON response body.
        :returns: the decoded JSON body, or ``{"error": message}`` when
            ``can_raise_exeption`` is false and an error occurred.
        """
        logger.info(u"Accessing endpoint {} with params={}".format(
            url, kwargs.get("params")))
        self.enforce_throttling()
        kwargs = template_dict(kwargs, **self.presets_variables)
        if self.loop_detector.is_stuck_in_loop(url, kwargs.get("params", {}),
                                               kwargs.get("headers", {})):
            raise RestAPIClientError(
                "The api-connect plugin is stuck in a loop. Please check the pagination parameters."
            )
        try:
            response = requests.request(method, url, **kwargs)
        except Exception as err:
            return self._handle_error("Error: {}".format(err),
                                      can_raise_exeption)
        self.time_last_request = time.time()
        if response.status_code >= 400:
            error_message = "Error {}: {}".format(response.status_code,
                                                  response.content)
            return self._handle_error(error_message, can_raise_exeption)
        try:
            json_response = response.json()
        except ValueError as err:
            # A non-JSON body goes through the same error contract as the
            # other failure paths instead of escaping as a raw ValueError.
            return self._handle_error(
                "Error: invalid JSON response ({})".format(err),
                can_raise_exeption)
        self.pagination.update_next_page(json_response)
        return json_response

    def _handle_error(self, error_message, can_raise_exeption):
        """Stop pagination, then raise or return ``error_message``.

        :raises RestAPIClientError: when ``can_raise_exeption`` is true.
        :returns: ``{"error": error_message}`` otherwise.
        """
        self.pagination.is_last_batch_empty = True
        if can_raise_exeption:
            raise RestAPIClientError(error_message)
        return {"error": error_message}

    def paginated_api_call(self, can_raise_exeption=True):
        """Request the next page, adding the current pagination parameters.

        The pagination parameters are merged in place into the stored
        ``params`` dict, so they persist in ``requests_kwargs``.
        """
        pagination_params = self.pagination.get_params()
        params = self.requests_kwargs.get("params")
        params.update(pagination_params)
        self.requests_kwargs.update({"params": params})
        return self.request(self.http_method,
                            self.pagination.get_next_page_url(),
                            can_raise_exeption, **self.requests_kwargs)

    @staticmethod
    def get_params(endpoint_query_string, keywords):
        """Render each value of a key/value setting as a template.

        ``endpoint_query_string`` is converted with ``get_dku_key_values``
        (presumably a DSS key/value list into a dict; verify against that
        helper). Each value is rendered with ``keywords``, and an empty or
        None rendering becomes "".
        """
        templated_query_string = get_dku_key_values(endpoint_query_string)
        return {
            key: format_template(templated_query_string.get(key, ""),
                                 **keywords) or ""
            for key in templated_query_string
        }

    def has_more_data(self):
        """Start paging on first use, then report whether a page remains."""
        if not self.pagination.is_paging_started:
            self.start_paging()
        return self.pagination.has_next_page()

    def start_paging(self):
        """Reset pagination to the endpoint URL, counting ``extraction_key``."""
        logger.info("Start paging with counting key '{}'".format(
            self.extraction_key))
        self.pagination.reset_paging(counting_key=self.extraction_key,
                                     url=self.endpoint_url)

    def enforce_throttling(self):
        """Sleep so that requests are spaced by ``time_between_requests``."""
        if self.time_between_requests and self.time_last_request:
            time_since_last_request = time.time() - self.time_last_request
            remaining = self.time_between_requests - time_since_last_request
            if remaining > 0:
                logger.info("Enforcing {}s throttling".format(remaining))
                time.sleep(remaining)
# Example #2
# 0
class JiraClient(object):
    """Client for the Jira (cloud or server) and Opsgenie REST APIs.

    Per-endpoint behaviour (URL structure, query string, pagination,
    payload key, error messages, column formatting) is read from
    ``api.endpoint_descriptors``.
    """

    JIRA_SITE_URL = "https://{subdomain}.atlassian.net/"
    OPSGENIE_SITE_URL = "https://{subdomain}.opsgenie.com/"

    def __init__(self, connection_details, api_name="jira"):
        """Store connection settings.

        :param connection_details: dict with ``api_url``, ``server_type``,
            ``username``, ``token``, ``subdomain`` and ``ignore_ssl_check``.
        :param api_name: ``"jira"`` (default) or ``"opsgenie"``.
        """
        logger.info("JiraClient init")
        self.api_name = api_name
        self.api_url = normalize_url(connection_details.get("api_url", ""))
        self.server_type = connection_details.get("server_type", "cloud")
        self.username = connection_details.get("username", "")
        # The token acts as the basic-auth password (Jira) or as the
        # GenieKey (Opsgenie, see get_auth_headers).
        self.password = connection_details.get("token", "")
        self.subdomain = connection_details.get("subdomain")
        self.ignore_ssl_check = connection_details.get("ignore_ssl_check",
                                                       False)
        self.site_url = self.get_site_url()
        self.params = {}
        self.pagination = Pagination()
        # Context of the last get_endpoint() call, reused by get_next_page().
        self.endpoint_name = None
        self.item_value = None
        self.queue_id = None

    def start_session(self, endpoint_name):
        """Load the endpoint descriptor and choose the record formatter."""
        self.endpoint_name = endpoint_name
        self.endpoint_descriptor = self.get_endpoint_descriptor(endpoint_name)
        self.formating = self.endpoint_descriptor.get(api.COLUMN_FORMATING, [])
        self.expanding = self.endpoint_descriptor.get(api.COLUMN_EXPANDING, [])
        self.cleaning = self.endpoint_descriptor.get(api.COLUMN_CLEANING, [])
        if not (self.formating or self.expanding or self.cleaning):
            # Nothing to reshape: skip the formatting loops entirely.
            self.format = self.return_data
        else:
            self.format = self.format_data

    def get_site_url(self):
        """Return the base site URL for the configured API and server type."""
        if self.is_opsgenie_api():
            return self.OPSGENIE_SITE_URL.format(subdomain=self.subdomain)
        if self.server_type == "cloud":
            return self.JIRA_SITE_URL.format(subdomain=self.subdomain)
        return self.api_url

    def is_opsgenie_api(self):
        return self.api_name == "opsgenie"

    def get_url(self, endpoint_name, item_value, queue_id):
        """Render the descriptor's API URL template for this call."""
        api_url = self.endpoint_descriptor[api.API]
        return api_url.format(site_url=self.site_url,
                              resource_name=self.get_resource_name(
                                  endpoint_name, item_value, queue_id))

    def get_resource_name(self, endpoint_name, item_value, queue_id):
        """Return the resource path: the templated structure when an item
        value is given, otherwise the bare endpoint name."""
        if item_value is None:
            return "{}".format(endpoint_name)
        ressource_structure = self.get_ressource_structure()
        return ressource_structure.format(endpoint_name=endpoint_name,
                                          item_value=item_value,
                                          queue_id=queue_id)

    def get_ressource_structure(self):
        return self.endpoint_descriptor[api.API_RESOURCE]

    def get_endpoint(self,
                     endpoint_name,
                     item_value,
                     data,
                     queue_id=None,
                     expand=None,
                     raise_exception=True):
        """Fetch the first page of an endpoint and return its records.

        :param data: optional request body forwarded to ``get``.
        :param expand: optional list of values joined into the
            ``{expand}`` query-string placeholder.
        :raises Exception: on HTTP status >= 400 when ``raise_exception``.
        :returns: list of records, or ``[{"error": message}]`` on HTTP
            error when ``raise_exception`` is false.
        """
        expand = [] if expand is None else expand
        self.endpoint_name = endpoint_name
        # Remembered so get_next_page() can filter and report errors with
        # the same context as the first page.
        self.item_value = item_value
        self.queue_id = queue_id
        self.params = self.get_params(endpoint_name, item_value, queue_id,
                                      expand)
        url = self.get_url(endpoint_name, item_value, queue_id)
        self.start_paging(counting_key=self.get_data_filter_key(), url=url)
        response = self.get(url, data, params=self.params)
        if response.status_code >= 400:
            self.pagination.set_error_flag(True)
            error_message = self._build_error_message(
                response, endpoint_name, item_value, queue_id)
            if raise_exception:
                raise Exception("{}".format(error_message))
            return [{"error": error_message}]

        response_data = response.json()
        self.pagination.update_next_page(response_data)
        return self.filter_data(response_data, item_value)

    def _build_error_message(self, response, endpoint_name, item_value,
                             queue_id):
        """Render the configured error template for a failed response."""
        error_template = self.get_error_messages_template(
            response.status_code)
        jira_error_message = self.get_jira_error_message(response)
        # Every placeholder is always supplied: the default template uses
        # {status_code} and {jira_error_message}.
        return error_template.format(endpoint_name=endpoint_name,
                                     item_value=item_value,
                                     queue_id=queue_id,
                                     status_code=response.status_code,
                                     jira_error_message=jira_error_message)

    def start_paging(self, counting_key, url):
        """Configure pagination from the descriptor and restart it at url."""
        pagination_config = self.get_pagination_config()
        self.pagination.configure_paging(pagination_config)
        self.pagination.reset_paging(counting_key=counting_key, url=url)

    def get_params(self, endpoint_name, item_value, queue_id, expand=None):
        """Render the descriptor's query-string templates into a dict."""
        expand_value = ",".join(expand or [])
        query_string_dict = self.get_query_string_dict()
        return {
            key: query_string_dict[key].format(endpoint_name=endpoint_name,
                                               item_value=item_value,
                                               queue_id=queue_id,
                                               expand=expand_value)
            for key in query_string_dict
        }

    def get_query_string_dict(self):
        return self.endpoint_descriptor.get(api.API_QUERY_STRING, {})

    def get_pagination_config(self):
        return self.endpoint_descriptor.get(api.PAGINATION, {})

    @staticmethod
    def get_jira_error_message(response):
        """Return the first error message in the JSON body, "" if there is
        none, or the raw text when the body cannot be inspected."""
        try:
            payload = response.json()
            if api.API_ERROR_MESSAGES in payload and len(
                    payload[api.API_ERROR_MESSAGES]) > 0:
                return payload[api.API_ERROR_MESSAGES][0]
            return ""
        except Exception:
            # Best effort: non-JSON or unexpected bodies fall back to text.
            return response.text

    def get_endpoint_descriptor(self, endpoint_name):
        """Return a private copy of the default descriptor, updated with the
        endpoint-specific entries when the endpoint is known."""
        endpoint_descriptor = copy.deepcopy(
            api.endpoint_descriptors[api.API_DEFAULT_DESCRIPTOR])
        if endpoint_name in api.endpoint_descriptors[api.ENDPOINTS]:
            update_dict(endpoint_descriptor,
                        api.endpoint_descriptors[api.ENDPOINTS][endpoint_name])
        return endpoint_descriptor

    def filter_data(self, data, item_value):
        """Extract the record list from a response payload.

        When the filtering key is a list, the entry chosen depends on
        whether the call had an item value ("" means no parameter).
        """
        filtering_key = self.get_data_filter_key()
        if isinstance(filtering_key, list):
            if item_value == "":
                filtering_key = filtering_key[FILTERING_KEY_WITHOUT_PARAMETER]
            else:
                filtering_key = filtering_key[FILTERING_KEY_WITH_PARAMETER]
        if filtering_key is None:
            return arrayed(data)
        return arrayed(data[filtering_key])

    def format_data(self, data):
        """Apply column formatting, expansion and cleaning to one record."""
        for key in self.formating:
            path = self.formating[key]
            data[key] = extract(data, path)
        for key in self.expanding:
            data = self.expand(data, key)
        for key in self.cleaning:
            data.pop(key, None)
        return escape_json(data)

    def return_data(self, data):
        return escape_json(data)

    def expand(self, dictionary, key_to_expand):
        """Flatten ``dictionary[key_to_expand]`` into ``parent_child`` keys
        on ``dictionary`` (in place) and drop the original key."""
        if key_to_expand in dictionary:
            self.dig(dictionary, dictionary[key_to_expand], [key_to_expand])
            dictionary.pop(key_to_expand, None)
        return dictionary

    def dig(self, dictionary, element_to_expand, path_to_element):
        """Recursively write the leaves of ``element_to_expand`` into
        ``dictionary`` under keys made by joining the path with "_"."""
        if not isinstance(element_to_expand, dict):
            dictionary["_".join(path_to_element)] = element_to_expand
        else:
            for key in element_to_expand:
                # A new list per branch; the path itself is never mutated.
                self.dig(dictionary, element_to_expand[key],
                         path_to_element + [key])

    def get_data_filter_key(self):
        """Return the payload key declared for status 200, or None."""
        if (api.API_RETURN in self.endpoint_descriptor) and (
                200 in self.endpoint_descriptor[api.API_RETURN]):
            return self.endpoint_descriptor[api.API_RETURN][200]
        return None

    def get_error_messages_template(self, status_code):
        """Return the error message template for ``status_code``."""
        # Deep copy: update_dict mutates its first argument, and the global
        # default descriptor must not absorb endpoint-specific messages.
        error_messages_template = copy.deepcopy(api.endpoint_descriptors[
            api.API_DEFAULT_DESCRIPTOR][api.API_RETURN])
        if api.API_RETURN in self.endpoint_descriptor:
            update_dict(error_messages_template,
                        self.endpoint_descriptor[api.API_RETURN])
        if status_code in error_messages_template:
            return error_messages_template[status_code]
        return "Error {status_code} - {jira_error_message}"

    def get(self, url, data=None, params=None):
        """Send a GET request with auth, headers and pagination parameters.

        The caller's ``params`` dict is not modified.
        """
        # Copy so pagination parameters do not leak into the caller's dict
        # (e.g. self.params) across calls.
        params = dict(params) if params else {}
        args = {}
        headers = self.get_headers()
        if headers is not None:
            args.update({"headers": headers})
        auth = self.get_auth()
        if auth is not None:
            args.update({"auth": auth})
        if data is not None:
            args.update({"data": data})
        if self.ignore_ssl_check:
            args.update({"verify": False})
        params.update(self.pagination.get_params())
        if params:
            args.update({"params": params})
        logger.info("Access Jira on endppoint {}".format(url))
        response = requests.get(url, **args)
        return response

    def get_auth(self):
        """Basic-auth tuple for Jira; None for Opsgenie (header auth)."""
        if self.is_opsgenie_api():
            return None
        return (self.username, self.password)

    def get_headers(self):
        headers = {"X-ExperimentalApi": "opt-in"}
        if self.is_opsgenie_api():
            headers["Authorization"] = self.get_auth_headers()
        return headers

    def get_auth_headers(self):
        return "GenieKey {}".format(self.password)

    def get_next_page(self):
        """Fetch the next page of the last get_endpoint() call.

        :raises Exception: on HTTP status >= 400.
        """
        logger.info("Loading next page")
        response = self.get(self.pagination.get_next_page_url(),
                            params=self.params)
        if response.status_code >= 400:
            self.pagination.set_error_flag(True)
            error_message = self._build_error_message(
                response, self.endpoint_name, self.item_value, self.queue_id)
            raise Exception("{}".format(error_message))
        data = response.json()
        self.pagination.update_next_page(data)
        return self.filter_data(data, self.item_value)