Example #1
0
 def test_set_logger_level_debug():
     os.environ['LOG_LEVEL'] = 'DEBUG'
     logger = log_service.get_module_logger(__name__)
     assert logger.level == 10
     assert logging.getLevelName(logger.level) == 'DEBUG'
import importlib.util
import json
import logging
import ntpath
import os
import re
import shutil
import stat
import tempfile
from pathlib import Path

from git import Repo
from google.cloud import storage

from composer.utils import log_service, auth_service
from composer.airflow import airflow_service
from composer.dag import dag_validator, dag_generator

# gets the logger for this module; level appears to be driven by the
# LOG_LEVEL environment variable (see the test fragments in this file)
logger = log_service.get_module_logger(__name__)


# [START __get_composer_environment]
def __get_composer_environment(project_id, location, composer_environment):
    """
    Gets the configuration information for a Cloud Composer environment.
    Args:
        project_id (string): GCP Project Id of the Cloud Composer instance
        location (string): GCP Zone of the Cloud Composer instance
        composer_environment (string): Name of the Cloud Composer instance
    Returns:
        an instance of composer.airflow.AirflowService
    """
    logger.log(logging.DEBUG, "Getting the composer environment")
    authenticated_session = auth_service.get_authenticated_session()
Example #3
0
 def test_set_logger_level_error():
     os.environ['LOG_LEVEL'] = 'ERROR'
     logger = log_service.get_module_logger(__name__)
     assert logger.level == 40
     assert logging.getLevelName(logger.level) == 'ERROR'
Example #4
0
 def test_set_logger_level_critical():
     os.environ['LOG_LEVEL'] = 'CRITICAL'
     logger = log_service.get_module_logger(__name__)
     assert logger.level == 50
     assert logging.getLevelName(logger.level) == 'CRITICAL'
Example #5
0
 def test_set_logger_level_warning():
     os.environ['LOG_LEVEL'] = 'WARNING'
     logger = log_service.get_module_logger(__name__)
     assert logger.level == 30
     assert logging.getLevelName(logger.level) == 'WARNING'
Example #6
0
 def test_set_logger_level_info():
     os.environ['LOG_LEVEL'] = 'INFO'
     logger = log_service.get_module_logger(__name__)
     assert logger.level == 20
     assert logging.getLevelName(logger.level) == 'INFO'
Example #7
0
class DagGenerator:
    """Class used to generate an Airflow DAG based on a JSON DSL definition"""

    # [START global variable definitions]
    INSERTION_MARKER = "# >>>INSERTION_MARKER<<<"
    DAG_TEMPLATE = "dag_template.py"
    EXTENSION_PYTHON = ".py"
    EXTENSION_JSON = ".json"
    # [END global variable definitions]

    # gets the logger for this module
    logger = log_service.get_module_logger(__name__)

    # [START DagGenerator constructor]
    def __init__(self, payload):
        """
        DagGenerator constructor.
        Args:
            payload (dict): the JSON DSL definition; must contain a 'dag_name' key
        """
        # sanitize the dag_name - replace whitespace with underscores and convert to lowercase
        self.payload = payload
        self.dag_name = re.sub(r'\s+', '_', payload['dag_name']).lower()
        self.temp_dir = tempfile.gettempdir()
        # define the path for the dag file and its associated json data
        # these are the paths where the concrete dag will be created
        self.dag_file = os.path.join(self.temp_dir, f"{self.dag_name}{self.EXTENSION_PYTHON}")
        self.json_file = os.path.join(self.temp_dir, f"{self.dag_name}{self.EXTENSION_JSON}")
    # [END DagGenerator constructor]

    # [START remove_previous_versions]
    def remove_previous_versions(self):
        """Removes any existing dag or json file that uses the provided dag name"""
        self.logger.log(logging.DEBUG, "Removes any previous version of the dag.")
        if os.path.exists(self.dag_file):
            os.remove(self.dag_file)
            self.logger.log(logging.INFO, f"Removed previous dag version: {self.dag_file}")
        else:
            self.logger.log(logging.INFO, f"No previous dag version: {self.dag_file}")

        if os.path.exists(self.json_file):
            os.remove(self.json_file)
            self.logger.log(logging.INFO, f"Removed previous json version: {self.json_file}")
        else:
            self.logger.log(logging.INFO, f"No previous json version: {self.json_file}")
    # [END remove_previous_versions]

    # [START copy_dag_template_to_file]
    def copy_dag_template_to_file(self):
        """Copies the dag template to a concrete dag file"""
        self.logger.log(logging.DEBUG, "Copying the dag template to a file.")
        # the template lives next to this module; __file__ is already a str,
        # so no Path() wrapper is needed
        template_src = os.path.join(os.path.dirname(__file__), self.DAG_TEMPLATE)
        self.logger.log(logging.INFO, f"Copying template: {template_src} to {self.dag_file}")
        shutil.copy(template_src, self.dag_file)
    # [END copy_dag_template_to_file]

    # [START write_payload_to_file]
    def write_payload_to_file(self):
        """Writes the provided, in-memory json payload to a concrete json file"""
        self.logger.log(logging.INFO, f"Writing payload to: {self.json_file}")
        with open(self.json_file, 'w') as f:
            json.dump(self.payload, f, indent=4)
    # [END write_payload_to_file]

    # [START insert_dynamic_data_to_dag]
    def insert_dynamic_data_to_dag(self):
        """
        Inserts dynamic data into the concrete dag file at the position defined
        by INSERTION_MARKER.
        Raises:
            ValueError: if the concrete dag file does not contain INSERTION_MARKER
        """
        self.logger.log(logging.DEBUG, "Inserting the dynamic data into the dag file.")
        # read the concrete dag file and find the insertion marker
        insertion_pos = None
        with open(self.dag_file, "r") as f:
            contents = f.readlines()
        for i, line in enumerate(contents):
            if self.INSERTION_MARKER in line:
                insertion_pos = i + 1
                break
        if insertion_pos is None:
            # previously this fell through to a NameError; fail with a clear message
            raise ValueError(
                f"Insertion marker '{self.INSERTION_MARKER}' not found in {self.dag_file}")

        """ 
        Inserts the path to the concrete json data file
            with open(os.path.join(os.path.dirname(Path(__file__)), 'json_file.json')) as f:
                payload = json.load(f)
        """
        contents.insert(
            insertion_pos,
            f"with open(os.path.join(os.path.dirname(Path(__file__)), "
            f"'{ntpath.basename(self.json_file)}')) as f:\n"
        )
        # trailing newline keeps this line from merging with the next template
        # line when written back via writelines()
        contents.insert(insertion_pos + 1, "    payload = json.load(f)\n")

        # write back the modified contents to the concrete dag file
        with open(self.dag_file, "w") as f:
            f.writelines(contents)
    # [END insert_dynamic_data_to_dag]

    # [START generate_dag]
    def generate_dag(self):
        """
        Generates a concrete dag file with its associated payload data in a concrete json file.
        Returns:
            a dict containing the path to the dag_file and the path to the json_file
        """
        self.logger.log(logging.DEBUG, "Generating the dag file.")
        self.remove_previous_versions()
        self.copy_dag_template_to_file()
        self.write_payload_to_file()
        self.insert_dynamic_data_to_dag()
        return {'dag_file': self.dag_file, 'json_file': self.json_file}
    # [END generate_dag]
class DagValidator:
    """Class used to validate and inspect an Airflow DAG"""

    # gets the logger for this module
    logger = log_service.get_module_logger(__name__)

    # [START DagValidator constructor]
    def __init__(self, dag_file):
        """
        DagValidator constructor.
        Args:
            dag_file (string): path to the concrete DAG file to validate/inspect
        """
        self.dag_file = dag_file

    # [END DagValidator constructor]

    # [START load_dag_module]
    def load_dag_module(self):
        """
        Dynamically loads a concrete DAG file as a python module
        Returns:
            the loaded python module built from the DAG file (not the DAG itself)
        """
        self.logger.log(logging.DEBUG, "Loading the dag module.")
        module_name = os.path.splitext(self.dag_file)[0]
        self.logger.log(
            logging.INFO,
            f"Loading dag module name: {module_name} from dag file: {self.dag_file}"
        )
        spec = importlib.util.spec_from_file_location(module_name,
                                                      self.dag_file)
        dag_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(dag_module)
        return dag_module

    # [END load_dag_module]

    # [START assert_has_valid_dag]
    def assert_has_valid_dag(self):
        """
        Assert that a module contains a valid DAG.
        Returns:
            a boolean indicating if a dag is found. True == found, False == not found
        Raises:
            AssertionError: if the dag file does not contain a valid DAG
        """
        self.logger.log(logging.DEBUG,
                        "Asserting if the provided dag is valid.")
        dag_module = self.load_dag_module()

        no_dag_found = True

        for dag in vars(dag_module).values():
            if isinstance(dag, models.DAG):
                self.logger.log(logging.INFO,
                                f"{dag_module} is a DAG instance")
                no_dag_found = False
                dag.test_cycle()  # Throws if a task cycle is found.

        if no_dag_found:
            raise AssertionError(
                f"DAG file {self.dag_file} does not contain a valid DAG")

        # a DAG was found (otherwise the AssertionError above fired);
        # return True per the documented contract (the original returned
        # no_dag_found, which was always False here)
        return not no_dag_found

    # [END assert_has_valid_dag]

    # [START validate_dag]
    def validate_dag(self):
        """Verifies that a concrete DAG file is valid."""
        self.logger.log(logging.DEBUG,
                        "Validating if the provided dag is valid.")
        self.assert_has_valid_dag()

    # [END validate_dag]

    # [START safe_serialize]
    def safe_serialize(self, obj):
        """
        Safely serialize a Python object to json - even when there are non-serializable attributes.
        Args:
            obj (Object): the object to be serialized
        Returns:
            a json string representing the serialized object
        """
        self.logger.log(logging.DEBUG, "Safely serializing the provided dag.")
        # non-serializable values are replaced with a placeholder string
        return json.dumps(
            obj,
            default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>",
            indent=4)

    # [END safe_serialize]

    # [START inspect_dag]
    def inspect_dag(self):
        """
        Loads a DAG, validates the DAG and then dumps the DAG structure to a JSON string.
        Returns:
            an instance of the validated dag represented as a json string
        Raises:
            AssertionError: if the dag file does not contain a valid DAG
        """
        self.logger.log(logging.DEBUG, "Inspecting the provided dag.")
        dag_module = self.load_dag_module()

        no_dag_found = True

        dag_as_str = ""

        for dag in vars(dag_module).values():
            if isinstance(dag, models.DAG):
                self.logger.log(logging.INFO,
                                f"{dag_module} is a DAG instance")

                no_dag_found = False

                # put the dag details into a dictionary
                dag_dict = vars(dag)

                # iterate the task details and add to array
                tasks = []
                for task in dag.tasks:
                    tasks.append(vars(task))

                # inject the task details array into the dag details dictionary
                dag_dict['tasks_details'] = tasks

                # safely serialize the dict to json
                dag_as_str = self.safe_serialize(dag_dict)

        if no_dag_found:
            raise AssertionError(
                f"DAG file {self.dag_file} does not contain a valid DAG")

        return dag_as_str
    # [END inspect_dag]
class AirflowService:
    """Class to interact with GCP Cloud Composer (Apache Airflow)"""

    IAM_SCOPE = 'https://www.googleapis.com/auth/iam'
    OAUTH_TOKEN_URI = 'https://www.googleapis.com/oauth2/v4/token'

    # gets the logger for this module
    logger = log_service.get_module_logger(__name__)

    # [START AirflowService constructor]
    def __init__(self, authenticated_session, project_id, location,
                 composer_environment):
        """
        AirflowService constructor.
        Args:
            authenticated_session (google.auth.transport.requests.AuthorizedSession): A GCP authenticated session
            project_id (string): GCP Project Id of the Cloud Composer instance
            location (string): GCP Zone of the Cloud Composer instance
            composer_environment (string): Name of the Cloud Composer instance
        """
        self.authenticated_session = authenticated_session
        self.project_id = project_id
        self.location = location
        self.composer_environment = composer_environment

    # [END AirflowService constructor]

    # [START get_airflow_config]
    def get_airflow_config(self):
        """
        Gets the details of the Cloud Composer environment
        Returns:
            a dictionary containing the details of the Cloud Composer environment,
            with the IAP redirect query string injected under 'query_string'
        """
        self.logger.log(logging.DEBUG, "Entered get_airflow_config method")
        environment_url = (
            'https://composer.googleapis.com/v1beta1/projects/{}/locations/{}'
            '/environments/{}').format(self.project_id, self.location,
                                       self.composer_environment)
        self.logger.log(logging.DEBUG,
                        f"Cloud Composer environment URL: {environment_url}")
        composer_response = self.authenticated_session.request(
            'GET', environment_url)
        environment_data = composer_response.json()
        airflow_uri = environment_data['config']['airflowUri']

        # The Composer environment response does not include the IAP client ID.
        # Make a second, unauthenticated HTTP request to the web server to get the
        # redirect URI.
        redirect_response = requests.get(airflow_uri, allow_redirects=False)
        redirect_location = redirect_response.headers['location']

        # Extract the client_id query parameter from the redirect.
        parsed = six.moves.urllib.parse.urlparse(redirect_location)
        query_string = six.moves.urllib.parse.parse_qs(parsed.query)
        environment_data['query_string'] = query_string
        return environment_data

    # [END get_airflow_config]

    # [START get_airflow_experimental_api]
    def get_airflow_experimental_api(self):
        """
        Gets the details of the Cloud Composer experimental API
        Returns:
            a tuple containing the airflow experimental api uri path and airflow client id
        """
        self.logger.log(logging.DEBUG,
                        "Entered get_airflow_experimental_api method")
        environment_data = self.get_airflow_config()
        airflow_uri = environment_data['config']['airflowUri']
        client_id = environment_data['query_string']['client_id'][0]
        return f"{airflow_uri}/api/experimental", client_id

    # [END get_airflow_experimental_api]

    # [START get_airflow_dag_gcs]
    def get_airflow_dag_gcs(self):
        """
        Gets the Google Cloud Storage path for the dag files in the Cloud Composer environment
        Returns:
            the name of the Cloud Composer Google Cloud Storage dag file bucket
        """
        self.logger.log(logging.DEBUG, "Entered get_airflow_dag_gcs method")
        environment_data = self.get_airflow_config()
        self.logger.log(
            logging.INFO,
            f"Google Cloud Storage path for the dag files: {environment_data['config']['dagGcsPrefix']}"
        )
        return environment_data['config']['dagGcsPrefix']

    # [END get_airflow_dag_gcs]

    # [START trigger_dag]
    def trigger_dag(self, dag_name, airflow_uri, client_id, data=None):
        """
        Makes a POST request to the Cloud Composer experimental API to trigger a DAG
        Args:
            dag_name (string): name of the DAG to trigger
            airflow_uri (string): the airflow experimental api uri path
            client_id (string): the IAP client id for the airflow web server
            data (dict): optional configuration passed to the DAG run as 'conf'
        Returns:
            the page body, or raises an exception if the page couldn't be retrieved.
        """
        self.logger.log(logging.DEBUG, "Entered trigger_dag method")
        self.logger.log(
            logging.DEBUG,
            f"trigger_dag method params. dag_name: {dag_name}, airflow_uri: {airflow_uri}, client_id: {client_id}, data: {data}"
        )
        webserver_url = f"{airflow_uri}/dags/{dag_name}/dag_runs"
        self.logger.log(logging.INFO, f"Web server URL: {webserver_url}")
        # Make a POST request to IAP which then Triggers the DAG
        if data:
            return self.make_post_iap_request(webserver_url, client_id,
                                              {"conf": data})
        else:
            return self.make_post_iap_request(webserver_url, client_id, {})

    # [END trigger_dag]

    # [START _get_open_id_connect_token]
    def _get_open_id_connect_token(self, client_id):
        """
        Obtains an OpenID Connect (OIDC) token for the given IAP client id.
        Tries the project auth_service first, then falls back to fetching a
        token from the metadata server / service account via id_token.
        Args:
            client_id: The client ID used by Identity-Aware Proxy.
        Returns:
            a Google-issued OpenID Connect token string
        """
        # try to get the id token from the auth_service
        google_open_id_connect_token = auth_service.get_id_token(
            Request(), client_id)
        if google_open_id_connect_token is None:
            google_open_id_connect_token = id_token.fetch_id_token(
                Request(), client_id)
        return google_open_id_connect_token

    # [END _get_open_id_connect_token]

    # [START make_post_iap_request]
    # This code is adapted from
    # https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/iap/make_iap_request.py
    # START COPIED IAP CODE
    def make_post_iap_request(self, url, client_id, json, **kwargs):
        """Makes a POST request to an application protected by Identity-Aware Proxy.
        Args:
          url: The Identity-Aware Proxy-protected URL to fetch.
          client_id: The client ID used by Identity-Aware Proxy.
          json: A JSON payload containing any additional data to be included with the POST request.
                (NOTE: this parameter shadows the stdlib json module inside this method.)
          **kwargs: Any of the parameters defined for the request function:
                    https://github.com/requests/requests/blob/master/requests/api.py
                    If no timeout is provided, it is set to 90 by default.
        Returns:
          The page body, or raises an exception if the page couldn't be retrieved.
        """

        self.logger.log(logging.DEBUG, "Entered make_post_iap_request method")
        self.logger.log(
            logging.DEBUG,
            f"make_post_iap_request method params. url: {url}, client_id: {client_id}, json: {json}"
        )

        # Set the default timeout, if missing
        if 'timeout' not in kwargs:
            kwargs['timeout'] = 90

        # Obtain an OpenID Connect (OIDC) token from metadata server or using service
        # account (shared fetch-with-fallback logic lives in the helper).
        google_open_id_connect_token = self._get_open_id_connect_token(client_id)

        # Fetch the Identity-Aware Proxy-protected URL, including an
        # Authorization header containing "Bearer " followed by a
        # Google-issued OpenID Connect token for the service account.
        resp = requests.request(
            'POST',
            url,
            headers={
                'Authorization':
                'Bearer {}'.format(google_open_id_connect_token),
                'Content-Type': 'application/json'
            },
            json=json,
            **kwargs)
        if resp.status_code == 403:
            raise Exception('Service account does not have permission to '
                            'access the IAP-protected application.')
        elif resp.status_code != 200:
            raise Exception(
                'Bad response from application: {!r} / {!r} / {!r}'.format(
                    resp.status_code, resp.headers, resp.text))
        else:
            return resp.text

    # END COPIED IAP CODE
    # [END make_post_iap_request]

    # [START make_get_iap_request]
    # This code is adapted from
    # https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/iap/make_iap_request.py
    # START COPIED IAP CODE
    def make_get_iap_request(self, url, client_id, **kwargs):
        """Makes a GET request to an application protected by Identity-Aware Proxy.
        Args:
          url: The Identity-Aware Proxy-protected URL to fetch.
          client_id: The client ID used by Identity-Aware Proxy.
          **kwargs: Any of the parameters defined for the request function:
                    https://github.com/requests/requests/blob/master/requests/api.py
                    If no timeout is provided, it is set to 90 by default.
        Returns:
          The page body, or raises an exception if the page couldn't be retrieved.
        """

        self.logger.log(logging.DEBUG, "Entered make_get_iap_request method")
        self.logger.log(
            logging.DEBUG,
            f"make_get_iap_request method params. url: {url}, client_id: {client_id}"
        )

        # Set the default timeout, if missing
        if 'timeout' not in kwargs:
            kwargs['timeout'] = 90

        # Obtain the OIDC token via the shared fetch-with-fallback helper.
        google_open_id_connect_token = self._get_open_id_connect_token(client_id)

        resp = requests.request(
            'GET',
            url,
            headers={
                'Authorization':
                'Bearer {}'.format(google_open_id_connect_token)
            },
            **kwargs)
        if resp.status_code == 403:
            raise Exception('Service account does not have permission to '
                            'access the IAP-protected application.')
        elif resp.status_code != 200:
            raise Exception(
                'Bad response from application: {!r} / {!r} / {!r}'.format(
                    resp.status_code, resp.headers, resp.text))
        else:
            return resp.text

    # END COPIED IAP CODE
    # [END make_get_iap_request]