Example #1
    def get_task_arns_for_location_arns(
        self,
        source_location_arns: list,
        destination_location_arns: list,
    ) -> list:
        """
        Return a list of TaskArns which use any one of the specified
        source LocationArns and any one of the specified destination LocationArns.

        :param list source_location_arns: List of source LocationArns.
        :param list destination_location_arns: List of destination LocationArns.
        :return: List of TaskArns.
        :rtype: list(str)
        :raises AirflowBadRequest: if ``source_location_arns`` or ``destination_location_arns`` are empty.
        """
        if not source_location_arns:
            raise AirflowBadRequest("source_location_arns not specified")
        if not destination_location_arns:
            raise AirflowBadRequest("destination_location_arns not specified")
        if not self.tasks:
            self._refresh_tasks()

        result = []
        for task in self.tasks:
            task_arn = task["TaskArn"]
            task_description = self.get_task_description(task_arn)
            if task_description["SourceLocationArn"] in source_location_arns:
                if task_description["DestinationLocationArn"] in destination_location_arns:
                    result.append(task_arn)
        return result
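
A minimal usage sketch for the helper above; the hook class name and the ARNs are illustrative, not from the source:

# Hypothetical usage; the hook class name and ARNs are illustrative.
hook = DataSyncHook(aws_conn_id="aws_default")
matching_task_arns = hook.get_task_arns_for_location_arns(
    source_location_arns=["arn:aws:datasync:us-east-1:111222333444:location/loc-src"],
    destination_location_arns=["arn:aws:datasync:us-east-1:111222333444:location/loc-dst"],
)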
Example #2
def create_pool(name, slots, description, session=None):
    """Create a pool with given parameters."""
    if not (name and name.strip()):
        raise AirflowBadRequest("Pool name shouldn't be empty")

    try:
        slots = int(slots)
    except ValueError:
        raise AirflowBadRequest(f"Bad value for `slots`: {slots}")

    # Get the length of the pool column
    pool_name_length = Pool.pool.property.columns[0].type.length
    if len(name) > pool_name_length:
        raise AirflowBadRequest(
            f"Pool name can't be more than {pool_name_length} characters")

    session.expire_on_commit = False
    pool = session.query(Pool).filter_by(pool=name).first()
    if pool is None:
        pool = Pool(pool=name, slots=slots, description=description)
        session.add(pool)
    else:
        pool.slots = slots
        pool.description = description

    session.commit()

    return pool
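
A sketch of calling this function; in Airflow the `session` argument is normally injected by a `@provide_session` decorator, so passing it explicitly here is an assumption:

# Sketch: create (or update) a pool with 8 slots; `session` is assumed to exist.
pool = create_pool(name="etl_pool", slots=8, description="ETL tasks", session=session)
print(pool.pool, pool.slots)  # -> etl_pool 8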
Example #3
 def poke(self, context):
     logging.info(
         f"Getting status for statement {self.statement_id} "
         f"in session {self.session_id}"
     )
     endpoint = f"{ENDPOINT}/{self.session_id}/statements/{self.statement_id}"
     response = HttpHook(method="GET", http_conn_id=self.http_conn_id).run(endpoint)
     try:
         statement = json.loads(response.content)
         state = statement["state"]
     except (JSONDecodeError, LookupError) as ex:
         log_response_error("$.state", response, self.session_id, self.statement_id)
         raise AirflowBadRequest(ex)
     if state in ["waiting", "running"]:
         logging.info(
             f"Statement {self.statement_id} in session {self.session_id} "
             f"has not finished yet (state is '{state}')"
         )
         return False
     if state == "available":
         self.__check_status(statement, response)
         return True
     raise AirflowBadRequest(
         f"Statement {self.statement_id} in session {self.session_id} failed due to "
         f"an unknown state: '{state}'.\nKnown states: 'waiting', 'running', "
         "'available'"
     )
Example #4
 def create_pool(self, name, slots, description):
     if not (name and name.strip()):
         raise AirflowBadRequest("Pool name shouldn't be empty")
     pool_name_length = Pool.pool.property.columns[0].type.length
     if len(name) > pool_name_length:
         raise AirflowBadRequest(
             f"pool name cannot be more than {pool_name_length} characters")
     try:
         slots = int(slots)
     except ValueError:
         raise AirflowBadRequest(f"Bad value for `slots`: {slots}")
     pool = Pool.create_or_update_pool(name=name,
                                       slots=slots,
                                       description=description)
     return pool.pool, pool.slots, pool.description
Example #5
 def spill_session_logs(self):
     dashes = 50
     logging.info(f"{'-'*dashes}Full log for session {self.session_id}{'-'*dashes}")
     endpoint = f"{ENDPOINT}/{self.session_id}/log"
     hook = HttpHook(method="GET", http_conn_id=self.http_conn_id)
     line_from = 0
     line_to = LOG_PAGE_LINES
     while True:
         log_page = self.fetch_log_page(hook, endpoint, line_from, line_to)
         try:
             logs = log_page["log"]
             for log in logs:
                 logging.info(log.replace("\\n", "\n"))
             actual_line_from = log_page["from"]
             total_lines = log_page["total"]
         except LookupError as ex:
             log_response_error("$.log, $.from, $.total", log_page, self.session_id)
             raise AirflowBadRequest(ex)
         actual_lines = len(logs)
         if actual_line_from + actual_lines >= total_lines:
             logging.info(
                 f"{'-' * dashes}End of full log for session {self.session_id}"
                 f"{'-' * dashes}"
             )
             break
         line_from = actual_line_from + actual_lines
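
The loop above pages through the log by advancing `line_from` by the number of lines each page actually returned, stopping once `from + len(log)` reaches `total`. A self-contained sketch of the same arithmetic over fake pages (all names illustrative):

# Illustrative pagination over fake log pages; mirrors the loop above.
pages = [
    {"log": ["a", "b"], "from": 0, "total": 5},
    {"log": ["c", "d"], "from": 2, "total": 5},
    {"log": ["e"], "from": 4, "total": 5},
]

def fetch(line_from):
    return next(p for p in pages if p["from"] == line_from)

line_from = 0
while True:
    page = fetch(line_from)
    for line in page["log"]:
        print(line)
    if page["from"] + len(page["log"]) >= page["total"]:
        break
    line_from = page["from"] + len(page["log"])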
Example #6
    def get_location_arns(self, location_uri, case_sensitive=True):
        """
        Return all LocationArns which match a LocationUri.

        :param str location_uri: Location URI to search for, e.g. ``s3://mybucket/mypath``
        :param bool case_sensitive: Do a case-sensitive search for the location URI.
        :return: List of LocationArns.
        :rtype: list(str)
        :raises AirflowBadRequest: if ``location_uri`` is empty
        """
        if not location_uri:
            raise AirflowBadRequest('location_uri not specified')
        if not self.locations:
            self._refresh_locations()
        result = []

        for location in self.locations:
            match = False
            if case_sensitive:
                match = location['LocationUri'] == location_uri
            else:
                match = location['LocationUri'].lower() == location_uri.lower()
            if match:
                result.append(location['LocationArn'])
        return result
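
A usage sketch (the hook instance is assumed):

# Case-insensitive lookup also matches e.g. s3://MyBucket/MyPath.
arns = hook.get_location_arns("s3://mybucket/mypath", case_sensitive=False)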
Example #7
    def create_collection(
        self,
        collection_name: str,
        database_name: Optional[str] = None,
        partition_key: Optional[str] = None,
    ) -> None:
        """Creates a new collection in the CosmosDB database."""
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        # Check whether this container already exists so we don't try
        # to create it twice.
        existing_container = list(self.get_conn().get_database_client(
            self.__get_database_name(database_name)).query_containers(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": collection_name}],
            ))

        # Only create if we did not find it already existing
        if len(existing_container) == 0:
            self.get_conn().get_database_client(
                self.__get_database_name(database_name)).create_container(
                    collection_name, partition_key=partition_key)
Example #8
 def poke(self, context):
     logging.info("Getting session {session_id} status...".format(
         session_id=self.session_id))
     endpoint = "{ENDPOINT}/{session_id}/state".format(
         ENDPOINT=ENDPOINT, session_id=self.session_id)
     response = HttpHook(method="GET",
                         http_conn_id=self.http_conn_id).run(endpoint)
     try:
         state = json.loads(response.content)["state"]
     except (JSONDecodeError, LookupError) as ex:
         log_response_error("$.state", response, self.session_id)
         raise AirflowBadRequest(ex)
     if state == "starting":
         logging.info("Session {session_id} is starting...".format(
             session_id=self.session_id))
         return False
     if state == "idle":
         logging.info(
             "Session {session_id} is ready to receive statements.".format(
                 session_id=self.session_id))
         return True
     raise AirflowException(
         "Session {session_id} failed to start. "
         "State='{state}'. Expected states: 'starting' or 'idle' (ready).".
         format(session_id=self.session_id, state=state))
Example #9
 def spill_batch_logs(self):
     """Fetch paginated batch logs from the Livy batch API and log them."""
     if not self.connections_created:
         self.create_livy_connections()
     dashes = 50
     self.log.info(f"{'-'*dashes}Full log for batch %s{'-'*dashes}", self.batch_id)
     endpoint = f"{LIVY_ENDPOINT}/{self.batch_id}/log"
     hook = self.LocalConnHttpHook(self, method="GET", http_conn_id='livy_conn_id')
     line_from = 0
     line_to = LOG_PAGE_LINES
     while True:
         log_page = self._fetch_log_page(hook, endpoint, line_from, line_to)
         try:
             logs = log_page["log"]
             for log in logs:
                 self.log.info(log.replace("\\n", "\n"))
             actual_line_from = log_page["from"]
             total_lines = log_page["total"]
         except LookupError as ex:
             self._log_response_error("$.log, $.from, $.total", log_page)
             raise AirflowBadRequest(ex)
         actual_lines = len(logs)
         if actual_line_from + actual_lines >= total_lines:
             self.log.info(
                 f"{'-' * dashes}End of full log for batch %s"
                 f"{'-' * dashes}", self.batch_id
             )
             break
         line_from = actual_line_from + actual_lines
Example #10
 def _check_spark_app_status(self, app_id):
     """
     Verify whether this Spark job has succeeded or failed
     by querying the Spark history server.

     :param app_id: application ID of the Spark job
     :raises AirflowException: when the job is verified to have failed
     """
     self.log.info("Getting app status (id=%s) from Spark REST API...", app_id)
     endpoint = f"{SPARK_ENDPOINT}/{app_id}/jobs"
     response = self.LocalConnHttpHook(self, method="GET", http_conn_id='spark_conn_id').run(
         endpoint
     )
     try:
         jobs = json.loads(response.content)
         expected_status = "SUCCEEDED"
         for job in jobs:
             job_id = job["jobId"]
             job_status = job["status"]
             self.log.info(
                 "Job id %s associated with application '%s' is '%s'",
                 job_id, app_id, job_status
             )
             if job_status != expected_status:
                 raise AirflowException(
                     f"Job id '{job_id}' associated with application '{app_id}' "
                     f"is '{job_status}', expected status is '{expected_status}'"
                 )
     except (JSONDecodeError, LookupError, TypeError) as ex:
         self._log_response_error("$.jobId, $.status", response)
         raise AirflowBadRequest(ex)
Example #11
 def spill_batch_logs(self):
     dashes = '-' * 50
     logging.info("{dashes}Full log for batch {batch_id}{dashes}".format(
         dashes=dashes, batch_id=self.batch_id))
     endpoint = "{LIVY_ENDPOINT}/{batch_id}/log".format(
         LIVY_ENDPOINT=LIVY_ENDPOINT, batch_id=self.batch_id)
     hook = HttpHook(method="GET", http_conn_id=self.http_conn_id_livy)
     line_from = 0
     line_to = LOG_PAGE_LINES
     while True:
         log_page = self.fetch_log_page(hook, endpoint, line_from, line_to)
         try:
             logs = log_page["log"]
             for log in logs:
                 logging.info(log.replace("\\n", "\n"))
             actual_line_from = log_page["from"]
             total_lines = log_page["total"]
         except LookupError as ex:
             log_response_error("$.log, $.from, $.total", log_page)
             raise AirflowBadRequest(ex)
         actual_lines = len(logs)
         if actual_line_from + actual_lines >= total_lines:
             logging.info("{dashes}End of full log for batch {batch_id}"
                          "{dashes}".format(dashes=dashes,
                                            batch_id=self.batch_id))
             break
         line_from = actual_line_from + actual_lines
Example #12
 def check_spark_app_status(self, app_id):
     logging.info(
         "Getting app status (id={app_id}) from Spark REST API...".format(
             app_id=app_id))
     endpoint = "{SPARK_ENDPOINT}/{app_id}/jobs".format(
         SPARK_ENDPOINT=SPARK_ENDPOINT, app_id=app_id)
     response = HttpHook(method="GET",
                         http_conn_id=self.http_conn_id_spark).run(endpoint)
     try:
         jobs = json.loads(response.content)
         expected_status = "SUCCEEDED"
         for job in jobs:
             job_id = job["jobId"]
             job_status = job["status"]
             logging.info(
                 "Job id {job_id} associated with application '{app_id}' "
                 "is '{job_status}'".format(job_id=job_id,
                                            app_id=app_id,
                                            job_status=job_status))
             if job_status != expected_status:
                 raise AirflowException(
                     "Job id '{job_id}' associated with application '{app_id}' "
                     "is '{job_status}', expected status is '{expected_status}'"
                     .format(job_id=job_id,
                             app_id=app_id,
                             job_status=job_status,
                             expected_status=expected_status))
     except (JSONDecodeError, LookupError, TypeError) as ex:
         log_response_error("$.jobId, $.status", response)
         raise AirflowBadRequest(ex)
Example #13
    def get_documents(
        self,
        sql_string: str,
        database_name: Optional[str] = None,
        collection_name: Optional[str] = None,
        partition_key: Optional[str] = None,
    ) -> Optional[list]:
        """Get a list of documents from an existing collection in the CosmosDB database via SQL query."""
        if sql_string is None:
            raise AirflowBadRequest("SQL query string cannot be None")

        # Build the SQL query payload
        query = {'query': sql_string}

        try:
            result_iterable = (
                self.get_conn()
                .get_database_client(self.__get_database_name(database_name))
                .get_container_client(self.__get_collection_name(collection_name))
                .query_items(query, partition_key=partition_key)
            )

            return list(result_iterable)
        except CosmosHttpResponseError:
            return None
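
A usage sketch (the hook instance and the query are illustrative):

# Sketch: query documents from the default collection; None means the query failed.
docs = hook.get_documents("SELECT * FROM c WHERE c.type = 'order'")
if docs is None:
    print("query failed")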
Example #14
 def __check_status(self, statement, response):
     try:
         output = statement["output"]
         status = output["status"]
     except LookupError as ex:
         log_response_error("$.output.status", response, self.session_id,
                            self.statement_id)
         raise AirflowBadRequest(ex)
     pp_output = "\n".join(json.dumps(output, indent=2).split("\\n"))
     logging.info(
         f"Statement {self.statement_id} in session {self.session_id} "
         f"finished. Output:\n{pp_output}")
     if status != "ok":
         raise AirflowBadRequest(
             f"Statement {self.statement_id} in session {self.session_id} "
             f"failed with status '{status}'. Expected status is 'ok'")
Example #15
    def upsert_document(self, document, database_name=None, collection_name=None, document_id=None):
        """
        Inserts a new document (or updates an existing one) into an existing
        collection in the CosmosDB database.
        """
        # Assign unique ID if one isn't provided
        if document_id is None:
            document_id = str(uuid.uuid4())

        if document is None:
            raise AirflowBadRequest("You cannot insert a None document")

        # Add a document id if one isn't set
        if document.get('id') is None:
            document['id'] = document_id

        created_document = (
            self.get_conn()
            .get_database_client(self.__get_database_name(database_name))
            .get_container_client(self.__get_collection_name(collection_name))
            .upsert_item(document)
        )

        return created_document
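
A usage sketch (hook instance and names illustrative):

# Sketch: upsert a document; an 'id' is generated because none is supplied.
doc = hook.upsert_document({"name": "example"}, database_name="db", collection_name="coll")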
Example #16
    def get_documents(
        self,
        sql_string: str,
        database_name: Optional[str] = None,
        collection_name: Optional[str] = None,
        partition_key: Optional[str] = None,
    ) -> Optional[list]:
        """Get a list of documents from an existing collection in the CosmosDB database via SQL query."""
        if sql_string is None:
            raise AirflowBadRequest("SQL query string cannot be None")

        # Build the SQL query payload
        query = {'query': sql_string}

        try:
            result_iterable = self.get_conn().QueryItems(
                get_collection_link(
                    self.__get_database_name(database_name),
                    self.__get_collection_name(collection_name)),
                query,
                partition_key=partition_key,
            )

            return list(result_iterable)
        except HTTPFailure:
            return None
Example #17
    def get_location_arns(
        self, location_uri: str, case_sensitive: bool = False, ignore_trailing_slash: bool = True
    ) -> List[str]:
        """
        Return all LocationArns which match a LocationUri.

        :param str location_uri: Location URI to search for, e.g. ``s3://mybucket/mypath``
        :param bool case_sensitive: Do a case-sensitive search for the location URI.
        :param bool ignore_trailing_slash: Ignore a trailing ``/`` in the URI when matching.
        :return: List of LocationArns.
        :rtype: list(str)
        :raises AirflowBadRequest: if ``location_uri`` is empty
        """
        if not location_uri:
            raise AirflowBadRequest("location_uri not specified")
        if not self.locations:
            self._refresh_locations()
        result = []

        if not case_sensitive:
            location_uri = location_uri.lower()
        if ignore_trailing_slash and location_uri.endswith("/"):
            location_uri = location_uri[:-1]

        for location_from_aws in self.locations:
            location_uri_from_aws = location_from_aws["LocationUri"]
            if not case_sensitive:
                location_uri_from_aws = location_uri_from_aws.lower()
            if ignore_trailing_slash and location_uri_from_aws.endswith("/"):
                location_uri_from_aws = location_uri_from_aws[:-1]
            if location_uri == location_uri_from_aws:
                result.append(location_from_aws["LocationArn"])
        return result
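
A sketch of the normalization this version adds (hook instance assumed):

# With the defaults, both calls match the same stored location:
hook.get_location_arns("s3://MyBucket/MyPath/")  # trailing slash ignored
hook.get_location_arns("s3://mybucket/mypath")   # matching is case-insensitive by default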
Example #18
    def create_collection(self,
                          collection_name: str,
                          database_name: Optional[str] = None) -> None:
        """Creates a new collection in the CosmosDB database."""
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        # Check whether this container already exists so we don't try
        # to create it twice.
        existing_container = list(self.get_conn().QueryContainers(
            get_database_link(self.__get_database_name(database_name)),
            {
                "query": "SELECT * FROM r WHERE r.id=@id",
                "parameters": [{
                    "name": "@id",
                    "value": collection_name
                }],
            },
        ))

        # Only create if we did not find it already existing
        if len(existing_container) == 0:
            self.get_conn().CreateContainer(
                get_database_link(self.__get_database_name(database_name)),
                {"id": collection_name})
Example #19
def delete_pool(name, session=None):
    """Delete pool by a given name."""
    if not (name and name.strip()):
        raise AirflowBadRequest("Pool name shouldn't be empty")

    if name == Pool.DEFAULT_POOL_NAME:
        raise AirflowBadRequest("default_pool cannot be deleted")

    pool = session.query(Pool).filter_by(pool=name).first()
    if pool is None:
        raise PoolNotFound(f"Pool '{name}' doesn't exist")

    session.delete(pool)
    session.commit()

    return pool
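
A sketch of the error path; as in the create example, the `session` is assumed to be injected:

# Sketch: deleting a missing pool raises PoolNotFound.
try:
    delete_pool(name="etl_pool", session=session)
except PoolNotFound:
    pass  # nothing to delete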
Example #20
    def delete_collection(self, collection_name: str, database_name: Optional[str] = None) -> None:
        """Deletes an existing collection in the CosmosDB database."""
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        self.get_conn().get_database_client(self.__get_database_name(database_name)).delete_container(
            collection_name
        )
Example #21
    def delete_database(self, database_name):
        """
        Deletes an existing database in CosmosDB.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        self.get_conn().DeleteDatabase(get_database_link(database_name))
Example #22
 def fetch_log_page(hook: HttpHook, endpoint, line_from, line_to):
     prepd_endpoint = endpoint + f"?from={line_from}&size={line_to}"
     response = hook.run(prepd_endpoint)
     try:
         return json.loads(response.content)
     except JSONDecodeError as ex:
         log_response_error("$", response)
         raise AirflowBadRequest(ex)
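
A usage sketch; the connection id, `ENDPOINT`, and `session_id` are illustrative:

# Sketch: fetch the first 100 log lines of a Livy session.
hook = HttpHook(method="GET", http_conn_id="livy_conn_id")
page = fetch_log_page(hook, f"{ENDPOINT}/{session_id}/log", line_from=0, line_to=100)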
Example #23
    def __get_collection_name(self, collection_name=None):
        coll_name = collection_name
        if coll_name is None:
            coll_name = self.default_collection_name

        if coll_name is None:
            raise AirflowBadRequest("Collection name must be specified")

        return coll_name
Example #24
 def _fetch_log_page(self, hook: LocalConnHttpHook, endpoint, line_from, line_to):
     """Fetch a paginated log page from the Livy batch API."""
     prepd_endpoint = endpoint + f"?from={line_from}&size={line_to}"
     response = hook.run(prepd_endpoint)
     try:
         return json.loads(response.content)
     except JSONDecodeError as ex:
         self._log_response_error("$", response)
         raise AirflowBadRequest(ex)
Example #25
    def delete_collection(self, collection_name, database_name=None):
        """
        Deletes an existing collection in the CosmosDB database.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        self.get_conn().DeleteContainer(
            get_collection_link(self.__get_database_name(database_name), collection_name))
Example #26
    def __get_database_name(self, database_name=None):
        db_name = database_name
        if db_name is None:
            db_name = self.default_database_name

        if db_name is None:
            raise AirflowBadRequest("Database name must be specified")

        return db_name
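
The resolution pattern used by these private helpers (explicit argument, else configured default, else error), restated as an illustrative standalone function:

from airflow.exceptions import AirflowBadRequest

def resolve_name(explicit, default, what):
    # Prefer the explicit value; fall back to the configured default.
    name = explicit if explicit is not None else default
    if name is None:
        raise AirflowBadRequest(f"{what} must be specified")
    return name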
Example #27
    def delete_document(
        self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
    ) -> None:
        """Delete an existing document out of a collection in the CosmosDB database."""
        if document_id is None:
            raise AirflowBadRequest("Cannot delete a document without an id")

        self.get_conn().get_database_client(self.__get_database_name(database_name)).get_container_client(
            self.__get_collection_name(collection_name)
        ).delete_item(document_id)
Example #28
    def __get_collection_name(self, collection_name: Optional[str] = None) -> str:
        # Called for its side effect: it populates the default names from the connection.
        self.get_conn()
        coll_name = collection_name
        if coll_name is None:
            coll_name = self.default_collection_name

        if coll_name is None:
            raise AirflowBadRequest("Collection name must be specified")

        return coll_name
Example #29
    def __get_database_name(self, database_name: Optional[str] = None) -> str:
        # Called for its side effect: it populates the default names from the connection.
        self.get_conn()
        db_name = database_name
        if db_name is None:
            db_name = self.default_database_name

        if db_name is None:
            raise AirflowBadRequest("Database name must be specified")

        return db_name
Example #30
    def cancel_task_execution(self, task_execution_arn: str) -> None:
        """
        Cancel a TaskExecution for the specified ``task_execution_arn``.

        :param str task_execution_arn: TaskExecutionArn.
        :raises AirflowBadRequest: If ``task_execution_arn`` is empty.
        """
        if not task_execution_arn:
            raise AirflowBadRequest("task_execution_arn not specified")
        self.get_conn().cancel_task_execution(TaskExecutionArn=task_execution_arn)
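
A usage sketch (hook instance and ARN illustrative):

# Sketch: cancel a running TaskExecution.
hook.cancel_task_execution(
    task_execution_arn="arn:aws:datasync:us-east-1:111222333444:task/task-x/execution/exec-y"
)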