import logging
import os

from composer.utils import log_service


def test_set_logger_level_debug():
    os.environ['LOG_LEVEL'] = 'DEBUG'
    logger = log_service.get_module_logger(__name__)
    assert logger.level == 10
    assert logging.getLevelName(logger.level) == 'DEBUG'
import logging
import re
import os
import tempfile
import shutil
import stat
import json

from git import Repo
from google.cloud import storage

from composer.utils import log_service, auth_service
from composer.airflow import airflow_service
from composer.dag import dag_validator, dag_generator

# gets the logger for this module
logger = log_service.get_module_logger(__name__)


# [START __get_composer_environment]
def __get_composer_environment(project_id, location, composer_environment):
    """
    Gets the configuration information for a Cloud Composer environment.

    Args:
        project_id (string): GCP Project Id of the Cloud Composer instance
        location (string): GCP Zone of the Cloud Composer instance
        composer_environment (string): Name of the Cloud Composer instance

    Returns:
        an instance of composer.airflow.AirflowService
    """
    logger.log(logging.DEBUG, "Getting the composer environment")
    authenticated_session = auth_service.get_authenticated_session()
    # build the AirflowService instance described in the docstring above
    return airflow_service.AirflowService(
        authenticated_session, project_id, location, composer_environment)
# [END __get_composer_environment]
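
# A minimal usage sketch (illustrative only; the project, location and
# environment values below are hypothetical placeholders):
#
#     service = __get_composer_environment(
#         'my-gcp-project', 'us-central1', 'my-composer-env')
#     dag_bucket = service.get_airflow_dag_gcs()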
def test_set_logger_level_error():
    os.environ['LOG_LEVEL'] = 'ERROR'
    logger = log_service.get_module_logger(__name__)
    assert logger.level == 40
    assert logging.getLevelName(logger.level) == 'ERROR'


def test_set_logger_level_critical():
    os.environ['LOG_LEVEL'] = 'CRITICAL'
    logger = log_service.get_module_logger(__name__)
    assert logger.level == 50
    assert logging.getLevelName(logger.level) == 'CRITICAL'


def test_set_logger_level_warning():
    os.environ['LOG_LEVEL'] = 'WARNING'
    logger = log_service.get_module_logger(__name__)
    assert logger.level == 30
    assert logging.getLevelName(logger.level) == 'WARNING'


def test_set_logger_level_info():
    os.environ['LOG_LEVEL'] = 'INFO'
    logger = log_service.get_module_logger(__name__)
    assert logger.level == 20
    assert logging.getLevelName(logger.level) == 'INFO'
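
# The numeric levels asserted above follow the stdlib logging mapping:
# DEBUG=10, INFO=20, WARNING=30, ERROR=40, CRITICAL=50. A minimal sketch of
# how the five tests above could be collapsed into one, assuming pytest is
# the test runner used for this suite:
import pytest


@pytest.mark.parametrize('level_name, level_value', [
    ('DEBUG', 10), ('INFO', 20), ('WARNING', 30),
    ('ERROR', 40), ('CRITICAL', 50)])
def test_set_logger_level_parametrized(level_name, level_value):
    # the logger level should track whatever LOG_LEVEL names
    os.environ['LOG_LEVEL'] = level_name
    logger = log_service.get_module_logger(__name__)
    assert logger.level == level_value
    assert logging.getLevelName(logger.level) == level_name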
import json
import logging
import ntpath
import os
import re
import shutil
import tempfile
from pathlib import Path

from composer.utils import log_service


class DagGenerator:
    """Class used to generate an Airflow DAG based on a JSON DSL definition"""

    # [START global variable definitions]
    INSERTION_MARKER = "# >>>INSERTION_MARKER<<<"
    DAG_TEMPLATE = "dag_template.py"
    EXTENSION_PYTHON = ".py"
    EXTENSION_JSON = ".json"
    # [END global variable definitions]

    # gets the logger for this module
    logger = log_service.get_module_logger(__name__)

    # [START DagGenerator constructor]
    def __init__(self, payload):
        # sanitize the dag_name - replace whitespace with underscores and convert to lowercase
        self.payload = payload
        self.dag_name = re.sub(r'\s+', '_', payload['dag_name']).lower()
        self.temp_dir = tempfile.gettempdir()
        # define the paths for the dag file and its associated json data;
        # these are the paths where the concrete dag will be created
        self.dag_file = os.path.join(self.temp_dir, f"{self.dag_name}{self.EXTENSION_PYTHON}")
        self.json_file = os.path.join(self.temp_dir, f"{self.dag_name}{self.EXTENSION_JSON}")
    # [END DagGenerator constructor]

    # [START remove_previous_versions]
    def remove_previous_versions(self):
        """Removes any existing dag or json file that uses the provided dag name"""
        self.logger.log(logging.DEBUG, "Removing any previous version of the dag.")
        if os.path.exists(self.dag_file):
            os.remove(self.dag_file)
            self.logger.log(logging.INFO, f"Removed previous dag version: {self.dag_file}")
        else:
            self.logger.log(logging.INFO, f"No previous dag version: {self.dag_file}")
        if os.path.exists(self.json_file):
            os.remove(self.json_file)
            self.logger.log(logging.INFO, f"Removed previous json version: {self.json_file}")
        else:
            self.logger.log(logging.INFO, f"No previous json version: {self.json_file}")
    # [END remove_previous_versions]

    # [START copy_dag_template_to_file]
    def copy_dag_template_to_file(self):
        """Copies the dag template to a concrete dag file"""
        self.logger.log(logging.DEBUG, "Copying the dag template to a file.")
        template_src = os.path.join(os.path.dirname(Path(__file__)), self.DAG_TEMPLATE)
        self.logger.log(logging.INFO, f"Copying template: {template_src} to {self.dag_file}")
        shutil.copy(template_src, self.dag_file)
    # [END copy_dag_template_to_file]

    # [START write_payload_to_file]
    def write_payload_to_file(self):
        """Writes the provided, in-memory json payload to a concrete json file"""
        self.logger.log(logging.INFO, f"Writing payload to: {self.json_file}")
        with open(self.json_file, 'w') as f:
            json.dump(self.payload, f, indent=4)
    # [END write_payload_to_file]

    # [START insert_dynamic_data_to_dag]
    def insert_dynamic_data_to_dag(self):
        """Inserts dynamic data into the concrete dag file at the position defined by INSERTION_MARKER"""
        self.logger.log(logging.DEBUG, "Inserting the dynamic data into the dag file.")
        # read the concrete dag file and find the insertion marker
        with open(self.dag_file, "r") as f:
            contents = f.readlines()
        insertion_pos = None
        for i, line in enumerate(contents):
            if self.INSERTION_MARKER in line:
                insertion_pos = i + 1
                break
        if insertion_pos is None:
            raise ValueError(
                f"Insertion marker {self.INSERTION_MARKER} not found in {self.dag_file}")
        # Inserts the code that loads the concrete json data file, i.e.:
        #
        #     with open(os.path.join(os.path.dirname(Path(__file__)), 'json_file.json')) as f:
        #         payload = json.load(f)
        contents.insert(
            insertion_pos,
            f"with open(os.path.join(os.path.dirname(Path(__file__)), "
            f"'{ntpath.basename(self.json_file)}')) as f:\n"
        )
        contents.insert(insertion_pos + 1, "    payload = json.load(f)\n")
        # write back the modified contents to the concrete dag file
        with open(self.dag_file, "w") as f:
            f.writelines(contents)
    # [END insert_dynamic_data_to_dag]

    # [START generate_dag]
    def generate_dag(self):
        """
        Generates a concrete dag file with its associated payload data
        in a concrete json file.

        Returns:
            a dictionary containing the path to the dag_file and the path to the json_file
        """
        self.logger.log(logging.DEBUG, "Generating the dag file.")
        self.remove_previous_versions()
        self.copy_dag_template_to_file()
        self.write_payload_to_file()
        self.insert_dynamic_data_to_dag()
        return {'dag_file': self.dag_file, 'json_file': self.json_file}
    # [END generate_dag]
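
# A minimal usage sketch, runnable as a script. The payload below is a
# hypothetical example of the JSON DSL (only 'dag_name' is known to be
# required here), and it assumes dag_template.py sits next to this module:
if __name__ == '__main__':
    sample_payload = {'dag_name': 'My Sample Dag'}  # hypothetical payload
    generator = DagGenerator(sample_payload)
    generated = generator.generate_dag()
    # the concrete files are written to the system temp directory
    print(f"dag file: {generated['dag_file']}, json file: {generated['json_file']}")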
import importlib.util
import json
import logging
import os

from airflow import models

from composer.utils import log_service


class DagValidator:
    """Class used to validate and inspect an Airflow DAG"""

    # gets the logger for this module
    logger = log_service.get_module_logger(__name__)

    # [START DagValidator constructor]
    def __init__(self, dag_file):
        self.dag_file = dag_file
    # [END DagValidator constructor]

    # [START load_dag_module]
    def load_dag_module(self):
        """
        Dynamically loads a concrete DAG file as a python module

        Returns:
            the loaded python module containing the DAG
        """
        self.logger.log(logging.DEBUG, "Loading the dag module.")
        module_name = os.path.splitext(self.dag_file)[0]
        self.logger.log(
            logging.INFO,
            f"Loading dag module name: {module_name} from dag file: {self.dag_file}")
        spec = importlib.util.spec_from_file_location(module_name, self.dag_file)
        dag_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(dag_module)
        return dag_module
    # [END load_dag_module]

    # [START assert_has_valid_dag]
    def assert_has_valid_dag(self):
        """
        Asserts that a module contains a valid DAG.

        Returns:
            False when a valid DAG is found; raises an AssertionError
            when the module does not contain a DAG
        """
        self.logger.log(logging.DEBUG, "Asserting if the provided dag is valid.")
        dag_module = self.load_dag_module()
        no_dag_found = True
        for dag in vars(dag_module).values():
            if isinstance(dag, models.DAG):
                self.logger.log(logging.INFO, f"{dag_module} is a DAG instance")
                no_dag_found = False
                dag.test_cycle()  # Throws if a task cycle is found.
        if no_dag_found:
            raise AssertionError(
                f"DAG file {self.dag_file} does not contain a valid DAG")
        return no_dag_found
    # [END assert_has_valid_dag]

    # [START validate_dag]
    def validate_dag(self):
        """Verifies that a concrete DAG file is valid."""
        self.logger.log(logging.DEBUG, "Validating the provided dag.")
        self.assert_has_valid_dag()
    # [END validate_dag]

    # [START safe_serialize]
    def safe_serialize(self, obj):
        """
        Safely serializes a Python object to json - even when there are
        non-serializable attributes.

        Args:
            obj (Object): the object to be serialized

        Returns:
            a json string representing the serialized object
        """
        self.logger.log(logging.DEBUG, "Safely serializing the provided dag.")
        return json.dumps(
            obj,
            default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>",
            indent=4)
    # [END safe_serialize]

    # [START inspect_dag]
    def inspect_dag(self):
        """
        Loads a DAG, validates the DAG and then dumps the DAG structure
        to a JSON string.

        Returns:
            the validated dag represented as a json string
        """
        self.logger.log(logging.DEBUG, "Inspecting the provided dag.")
        dag_module = self.load_dag_module()
        no_dag_found = True
        dag_as_str = ""
        for dag in vars(dag_module).values():
            if isinstance(dag, models.DAG):
                self.logger.log(logging.INFO, f"{dag_module} is a DAG instance")
                no_dag_found = False
                # put the dag details into a dictionary
                dag_dict = vars(dag)
                # iterate the task details and add to an array
                tasks = []
                for task in dag.tasks:
                    tasks.append(vars(task))
                # inject the task details array into the dag details dictionary
                dag_dict['tasks_details'] = tasks
                # safely serialize the dict to json
                dag_as_str = self.safe_serialize(dag_dict)
        if no_dag_found:
            raise AssertionError(
                f"DAG file {self.dag_file} does not contain a valid DAG")
        return dag_as_str
    # [END inspect_dag]
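
# A minimal usage sketch, runnable as a script. The dag file path below is
# hypothetical; in practice it would be the 'dag_file' path returned by
# DagGenerator.generate_dag():
if __name__ == '__main__':
    validator = DagValidator('/tmp/my_sample_dag.py')  # hypothetical path
    # raises AssertionError if the file does not contain a valid DAG
    validator.validate_dag()
    # dump the DAG and task structure as a json string
    print(validator.inspect_dag())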
import logging

import requests
import six
from google.auth.transport.requests import Request
from google.oauth2 import id_token

from composer.utils import auth_service, log_service


class AirflowService:
    """Class to interact with GCP Cloud Composer (Apache Airflow)"""

    IAM_SCOPE = 'https://www.googleapis.com/auth/iam'
    OAUTH_TOKEN_URI = 'https://www.googleapis.com/oauth2/v4/token'

    # gets the logger for this module
    logger = log_service.get_module_logger(__name__)

    # [START AirflowService constructor]
    def __init__(self, authenticated_session, project_id, location, composer_environment):
        """
        AirflowService constructor.

        Args:
            authenticated_session (google.auth.transport.requests.AuthorizedSession):
                a GCP authenticated session
            project_id (string): GCP Project Id of the Cloud Composer instance
            location (string): GCP Zone of the Cloud Composer instance
            composer_environment (string): Name of the Cloud Composer instance
        """
        self.authenticated_session = authenticated_session
        self.project_id = project_id
        self.location = location
        self.composer_environment = composer_environment
    # [END AirflowService constructor]

    # [START get_airflow_config]
    def get_airflow_config(self):
        """
        Gets the details of the Cloud Composer environment

        Returns:
            a dictionary containing the details of the Cloud Composer environment
        """
        self.logger.log(logging.DEBUG, "Entered get_airflow_config method")
        environment_url = (
            'https://composer.googleapis.com/v1beta1/projects/{}/locations/{}'
            '/environments/{}').format(self.project_id, self.location, self.composer_environment)
        self.logger.log(logging.DEBUG, f"Cloud Composer environment URL: {environment_url}")
        composer_response = self.authenticated_session.request('GET', environment_url)
        environment_data = composer_response.json()
        airflow_uri = environment_data['config']['airflowUri']
        # The Composer environment response does not include the IAP client ID.
        # Make a second, unauthenticated HTTP request to the web server to get the
        # redirect URI.
        redirect_response = requests.get(airflow_uri, allow_redirects=False)
        redirect_location = redirect_response.headers['location']
        # Extract the client_id query parameter from the redirect.
        parsed = six.moves.urllib.parse.urlparse(redirect_location)
        query_string = six.moves.urllib.parse.parse_qs(parsed.query)
        environment_data['query_string'] = query_string
        return environment_data
    # [END get_airflow_config]

    # [START get_airflow_experimental_api]
    def get_airflow_experimental_api(self):
        """
        Gets the details of the Cloud Composer experimental API

        Returns:
            a tuple containing the airflow experimental api uri path and airflow client id
        """
        self.logger.log(logging.DEBUG, "Entered get_airflow_experimental_api method")
        environment_data = self.get_airflow_config()
        airflow_uri = environment_data['config']['airflowUri']
        client_id = environment_data['query_string']['client_id'][0]
        return f"{airflow_uri}/api/experimental", client_id
    # [END get_airflow_experimental_api]

    # [START get_airflow_dag_gcs]
    def get_airflow_dag_gcs(self):
        """
        Gets the Google Cloud Storage path for the dag files in the
        Cloud Composer environment

        Returns:
            the name of the Cloud Composer Google Cloud Storage dag file bucket
        """
        self.logger.log(logging.DEBUG, "Entered get_airflow_dag_gcs method")
        environment_data = self.get_airflow_config()
        self.logger.log(
            logging.INFO,
            f"Google Cloud Storage path for the dag files: "
            f"{environment_data['config']['dagGcsPrefix']}")
        return environment_data['config']['dagGcsPrefix']
    # [END get_airflow_dag_gcs]

    # [START trigger_dag]
    def trigger_dag(self, dag_name, airflow_uri, client_id, data=None):
        """
        Makes a POST request to the Cloud Composer experimental API to trigger a DAG

        Returns:
            the page body, or raises an exception if the page couldn't be retrieved.
        """
        self.logger.log(logging.DEBUG, "Entered trigger_dag method")
        self.logger.log(
            logging.DEBUG,
            f"trigger_dag method params. dag_name: {dag_name}, "
            f"airflow_uri: {airflow_uri}, client_id: {client_id}, data: {data}")
        webserver_url = f"{airflow_uri}/dags/{dag_name}/dag_runs"
        self.logger.log(logging.INFO, f"Web server URL: {webserver_url}")
        # Make a POST request to IAP which then triggers the DAG
        if data:
            return self.make_post_iap_request(webserver_url, client_id, {"conf": data})
        else:
            return self.make_post_iap_request(webserver_url, client_id, {})
    # [END trigger_dag]

    # [START make_post_iap_request]
    # This code is copied from
    # https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/iap/make_iap_request.py
    # START COPIED IAP CODE
    def make_post_iap_request(self, url, client_id, json, **kwargs):
        """Makes a POST request to an application protected by Identity-Aware Proxy.

        Args:
            url: The Identity-Aware Proxy-protected URL to fetch.
            client_id: The client ID used by Identity-Aware Proxy.
            json: A JSON payload containing any additional data to be included
                with the POST request.
            **kwargs: Any of the parameters defined for the request function:
                https://github.com/requests/requests/blob/master/requests/api.py
                If no timeout is provided, it is set to 90 by default.

        Returns:
            The page body, or raises an exception if the page couldn't be retrieved.
        """
        self.logger.log(logging.DEBUG, "Entered make_post_iap_request method")
        self.logger.log(
            logging.DEBUG,
            f"make_post_iap_request method params. url: {url}, "
            f"client_id: {client_id}, json: {json}")
        # Set the default timeout, if missing
        if 'timeout' not in kwargs:
            kwargs['timeout'] = 90
        # Obtain an OpenID Connect (OIDC) token from metadata server or using
        # service account.
        # try to get the id token from the auth_service
        google_open_id_connect_token = auth_service.get_id_token(Request(), client_id)
        if google_open_id_connect_token is None:
            google_open_id_connect_token = id_token.fetch_id_token(Request(), client_id)
        # Fetch the Identity-Aware Proxy-protected URL, including an
        # Authorization header containing "Bearer " followed by a
        # Google-issued OpenID Connect token for the service account.
        resp = requests.request(
            'POST', url,
            headers={
                'Authorization': 'Bearer {}'.format(google_open_id_connect_token),
                'Content-Type': 'application/json'
            },
            json=json, **kwargs)
        if resp.status_code == 403:
            raise Exception('Service account does not have permission to '
                            'access the IAP-protected application.')
        elif resp.status_code != 200:
            raise Exception(
                'Bad response from application: {!r} / {!r} / {!r}'.format(
                    resp.status_code, resp.headers, resp.text))
        else:
            return resp.text
    # END COPIED IAP CODE
    # [END make_post_iap_request]

    # [START make_get_iap_request]
    # This code is copied from
    # https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/iap/make_iap_request.py
    # START COPIED IAP CODE
    def make_get_iap_request(self, url, client_id, **kwargs):
        """Makes a GET request to an application protected by Identity-Aware Proxy.

        Args:
            url: The Identity-Aware Proxy-protected URL to fetch.
            client_id: The client ID used by Identity-Aware Proxy.
            **kwargs: Any of the parameters defined for the request function:
                https://github.com/requests/requests/blob/master/requests/api.py
                If no timeout is provided, it is set to 90 by default.

        Returns:
            The page body, or raises an exception if the page couldn't be retrieved.
        """
        self.logger.log(logging.DEBUG, "Entered make_get_iap_request method")
        self.logger.log(
            logging.DEBUG,
            f"make_get_iap_request method params. url: {url}, client_id: {client_id}")
        # Set the default timeout, if missing
        if 'timeout' not in kwargs:
            kwargs['timeout'] = 90
        # try to get the id token from the auth_service
        google_open_id_connect_token = auth_service.get_id_token(Request(), client_id)
        if google_open_id_connect_token is None:
            google_open_id_connect_token = id_token.fetch_id_token(Request(), client_id)
        resp = requests.request(
            'GET', url,
            headers={
                'Authorization': 'Bearer {}'.format(google_open_id_connect_token)
            },
            **kwargs)
        if resp.status_code == 403:
            raise Exception('Service account does not have permission to '
                            'access the IAP-protected application.')
        elif resp.status_code != 200:
            raise Exception(
                'Bad response from application: {!r} / {!r} / {!r}'.format(
                    resp.status_code, resp.headers, resp.text))
        else:
            return resp.text
    # END COPIED IAP CODE
    # [END make_get_iap_request]
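
# A minimal usage sketch, runnable as a script. The project, location,
# environment and dag names below are hypothetical placeholders, and this
# assumes auth_service.get_authenticated_session() returns the
# AuthorizedSession expected by the constructor above:
if __name__ == '__main__':
    session = auth_service.get_authenticated_session()
    service = AirflowService(
        session, 'my-gcp-project', 'us-central1', 'my-composer-env')
    # resolve the experimental API endpoint and the IAP client id
    api_uri, client_id = service.get_airflow_experimental_api()
    # trigger a dag run, passing an optional conf payload
    response = service.trigger_dag(
        'my_sample_dag', api_uri, client_id, data={'key': 'value'})
    print(response)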