def execute(self, context):
    cluster_id = self.config.get("CLUSTER", "CLUSTER_ID")
    aws_hook = AwsHook(self.aws_credentials_id)
    aws_credentials = aws_hook.get_credentials()
    logging.info(f"Creating redshift cluster {cluster_id}...")
    redshift = boto3.client(
        "redshift",
        region_name="us-west-2",
        aws_access_key_id=aws_credentials.access_key,
        aws_secret_access_key=aws_credentials.secret_key,
    )
    try:
        response = redshift.create_cluster(
            ClusterIdentifier=cluster_id,
            ClusterType=self.config.get("CLUSTER", "CLUSTER_TYPE"),
            NodeType=self.config.get("CLUSTER", "NODE_TYPE"),
            NumberOfNodes=int(self.config.get("CLUSTER", "NUMBER_OF_NODES")),
            DBName=self.config.get("CLUSTER", "DB_NAME"),
            Port=int(self.config.get("CLUSTER", "DB_PORT")),
            MasterUsername=self.config.get("CLUSTER", "DB_USER"),
            MasterUserPassword=self.config.get("CLUSTER", "DB_PASSWORD"),
            IamRoles=[self.config.get("IAM_ROLE", "ARN")],
        )
        logging.info(f"{cluster_id} created")
        logging.info(response)
    except redshift.exceptions.ClusterAlreadyExistsFault:
        logging.info(f"{cluster_id} already exists")
    except Exception as e:
        raise AirflowFailException(e)
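# Hypothetical wrapper for the execute() above: a custom operator holding the
# cluster config and the Airflow connection id used for AWS credentials. The
# class name, constructor signature, default connection id, and import path are
# assumptions for illustration, not part of the original snippet.
import configparser

from airflow.models import BaseOperator


class CreateRedshiftClusterOperator(BaseOperator):  # assumed class name
    def __init__(self, config: configparser.ConfigParser,
                 aws_credentials_id: str = "aws_credentials",  # assumed default
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.aws_credentials_id = aws_credentials_id

    # execute(self, context) as defined above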
def extract_nyt_best_sellers(
        date: str = 'current',
        list_name: list = ['combined-print-and-e-book-fiction',
                           'combined-print-and-e-book-nonfiction']) -> list:
    '''
    Extracts New York Times best seller list information for
    combined-print-and-e-book-fiction and combined-print-and-e-book-nonfiction
    and returns a list of results.
    '''
    requestHeaders = {"Accept": "application/json"}
    nyt_best_sellers = []
    date = Variable.get('next_update_date')
    api_key = Variable.get('NYT_API_KEY')

    for list_ in list_name:
        url = f'https://api.nytimes.com/svc/books/v3/lists/{date}/{list_}.json?api-key={api_key}'
        data = requests.get(url, headers=requestHeaders)
        if not data:
            raise AirflowFailException('Date or Data Point not available')
        else:
            nyt_best_sellers.append(data.json()['results'])

    # If the requests succeed, then update the next_update_date variable
    next_update_date = datetime.strptime(date, '%Y-%m-%d')
    next_update_date += timedelta(days=7)
    next_update_date = next_update_date.strftime('%Y-%m-%d')
    Variable.set('next_update_date', next_update_date)

    return nyt_best_sellers
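# Hypothetical DAG wiring for extract_nyt_best_sellers; the dag_id, schedule,
# and start_date below are assumptions. The callable reads the
# next_update_date and NYT_API_KEY Airflow Variables, so both must be seeded
# (via the UI or CLI) before the first run.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

with DAG(
    dag_id="nyt_best_sellers",        # assumed name
    start_date=datetime(2023, 1, 1),
    schedule_interval="@weekly",
    catchup=False,
) as dag:
    extract_task = PythonOperator(
        task_id="extract_nyt_best_sellers",
        python_callable=extract_nyt_best_sellers,
    )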
def execute(self, context: "Context") -> List[str]:
    hook = GoogleDriveHook(
        gcp_conn_id=self.gcp_conn_id,
        delegate_to=self.delegate_to,
        impersonation_chain=self.impersonation_chain,
    )
    remote_file_ids = []

    for local_path in self.local_paths:
        self.log.info("Uploading file to Google Drive: %s", local_path)
        try:
            remote_file_id = hook.upload_file(
                local_location=str(local_path),
                remote_location=str(Path(self.drive_folder) / Path(local_path).name),
                chunk_size=self.chunk_size,
                resumable=self.resumable,
            )
            remote_file_ids.append(remote_file_id)

            if self.delete:
                os.remove(local_path)
                self.log.info("Deleted local file: %s", local_path)
        except FileNotFoundError:
            self.log.warning("File can't be found: %s", local_path)
        except OSError:
            self.log.warning("An OSError occurred for file: %s", local_path)

    if not self.ignore_if_missing and len(remote_file_ids) < len(self.local_paths):
        raise AirflowFailException("Some files couldn't be uploaded")
    return remote_file_ids
def create_iam_role(**kwargs):
    assume_role_policy_document = {
        "Version": "2012-10-17",
        "Statement": [{
            "Effect": "Allow",
            "Principal": {"Service": "personalize.amazonaws.com"},
            "Action": "sts:AssumeRole"
        }]
    }
    try:
        create_role_response = iam.create_role(
            RoleName=PERSONALIZE_ROLE_NAME,
            AssumeRolePolicyDocument=json.dumps(assume_role_policy_document))
        iam.attach_role_policy(
            RoleName=PERSONALIZE_ROLE_NAME,
            PolicyArn="arn:aws:iam::aws:policy/AmazonS3ReadOnlyAccess")
        role_arn = create_role_response["Role"]["Arn"]
        # sometimes need to wait a bit for the role to be created
        time.sleep(30)
        return role_arn
    except ClientError as e:
        if e.response['Error']['Code'] == 'EntityAlreadyExists':
            role_arn = iam.get_role(RoleName=PERSONALIZE_ROLE_NAME)['Role']['Arn']
            time.sleep(30)
            return role_arn
        else:
            raise AirflowFailException("PersonalizeS3Role create failed")
def trigger_loop(self, submission, task_list, mdata):
    # loop_num = None
    c = Client(None, None)
    submission_dict = submission.serialize()
    task_dict_list = [task.serialize() for task in task_list]
    submission_hash = submission.submission_hash
    self.run_id = f"dag_run_{submission_hash}"
    try:
        c.trigger_dag(
            dag_id=self.dag_name,
            run_id=self.run_id,
            conf={
                'submission_dict': submission_dict,
                'task_dict_list': task_dict_list,
                'mdata': mdata
            })
    except DagRunAlreadyExists:
        dag_run_state = self.get_dag_run_state()
        if dag_run_state == State.FAILED:
            raise AirflowFailException(
                f"subdag dag_run fail dag_id:{self.dag_name}; run_id:{self.run_id};"
            )
        else:
            print(f"continue from old dag_run {self.run_id}")
    loop_return = self.wait_until_end()
    return loop_return
def run_soda_scan(warehouse_yml_file, scan_yml_file):
    scan_builder = ScanBuilder()
    scan_builder.warehouse_yml_file = warehouse_yml_file
    scan_builder.scan_yml_file = scan_yml_file
    scan = scan_builder.build()
    scan_result = scan.execute()
    if scan_result.has_test_failures():
        failures = scan_result.get_test_failures_count()
        raise AirflowFailException(f"Soda Scan found {failures} errors in your data!")
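# Hypothetical wiring of run_soda_scan into a DAG as a PythonOperator; the
# dag_id, schedule, and the warehouse/scan YAML paths are assumptions used only
# to illustrate how the two arguments would be supplied.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

with DAG(
    dag_id="soda_data_quality",        # assumed name
    start_date=datetime(2023, 1, 1),
    schedule_interval="@daily",
    catchup=False,
) as dag:
    soda_scan = PythonOperator(
        task_id="run_soda_scan",
        python_callable=run_soda_scan,
        op_kwargs={
            "warehouse_yml_file": "/opt/soda/warehouse.yml",  # assumed path
            "scan_yml_file": "/opt/soda/tables/orders.yml",   # assumed path
        },
    )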
def import_dataset(**kwargs):
    ti = kwargs['ti']
    interactions_dataset_arn = ti.xcom_pull(key="return_value",
                                            task_ids='create_dataset_type')
    role_arn = ti.xcom_pull(key="return_value", task_ids='create_iam_role')
    if not interactions_dataset_arn:
        dataset_group_arn = ti.xcom_pull(key="dataset_group_arn",
                                         task_ids='check_dataset_group')
        list_datasets_response = personalize.list_datasets(
            datasetGroupArn=dataset_group_arn, maxResults=100)
        interaction_dataset = next(
            (dataset for dataset in list_datasets_response["datasets"]
             if dataset["name"] == INTERACTION_DATASET_NAME), False)
        interactions_dataset_arn = interaction_dataset["datasetArn"]
    if not role_arn:
        role_arn = iam.get_role(RoleName=PERSONALIZE_ROLE_NAME)['Role']['Arn']
        time.sleep(30)

    create_dataset_import_job_response = personalize.create_dataset_import_job(
        jobName="DEMO-dataset-import-job-" + suffix,
        datasetArn=interactions_dataset_arn,
        dataSource={"dataLocation": OUTPUT_PATH},
        roleArn=role_arn)
    dataset_import_job_arn = create_dataset_import_job_response['datasetImportJobArn']
    logger.info(json.dumps(create_dataset_import_job_response, indent=2))

    status = None
    max_time = time.time() + 2 * 60 * 60  # 2 hours
    while time.time() < max_time:
        describe_dataset_import_job_response = personalize.describe_dataset_import_job(
            datasetImportJobArn=dataset_import_job_arn)
        dataset_import_job = describe_dataset_import_job_response["datasetImportJob"]

        if "latestDatasetImportJobRun" not in dataset_import_job:
            status = dataset_import_job["status"]
            logger.info("DatasetImportJob: {}".format(status))
        else:
            status = dataset_import_job["latestDatasetImportJobRun"]["status"]
            logger.info("LatestDatasetImportJobRun: {}".format(status))

        if status == "ACTIVE" or status == "CREATE FAILED":
            break
        time.sleep(60)

    if status == "ACTIVE":
        return dataset_import_job_arn
    if status == "CREATE FAILED":
        raise AirflowFailException("Dataset import job create failed")
def update_solution(**kwargs):
    recipe_arn = "arn:aws:personalize:::recipe/aws-popularity-count"
    ti = kwargs['ti']
    dataset_group_arn = ti.xcom_pull(key="dataset_group_arn",
                                     task_ids='check_dataset_group')
    if not dataset_group_arn:
        dataset_group_arn = ti.xcom_pull(key="dataset_group_arn",
                                         task_ids='create_dataset_group')

    list_solutions_response = personalize.list_solutions(
        datasetGroupArn=dataset_group_arn, maxResults=100)
    demo_solution = next((solution
                          for solution in list_solutions_response["solutions"]
                          if solution["name"] == SOLUTION_NAME), False)
    if not demo_solution:
        create_solution_response = personalize.create_solution(
            name=SOLUTION_NAME,
            datasetGroupArn=dataset_group_arn,
            recipeArn=recipe_arn)
        solution_arn = create_solution_response['solutionArn']
        logger.info(json.dumps(create_solution_response, indent=2))
    else:
        solution_arn = demo_solution["solutionArn"]
    kwargs['ti'].xcom_push(key="solution_arn", value=solution_arn)

    create_solution_version_response = personalize.create_solution_version(
        solutionArn=solution_arn, trainingMode='FULL')
    solution_version_arn = create_solution_version_response['solutionVersionArn']

    status = None
    max_time = time.time() + 2 * 60 * 60  # 2 hours
    while time.time() < max_time:
        describe_solution_version_response = personalize.describe_solution_version(
            solutionVersionArn=solution_version_arn)
        status = describe_solution_version_response["solutionVersion"]["status"]
        logger.info(f"SolutionVersion: {status}")

        if status == "ACTIVE" or status == "CREATE FAILED":
            break
        time.sleep(60)

    if status == "ACTIVE":
        return solution_version_arn
    if status == "CREATE FAILED":
        raise AirflowFailException("Solution version create failed")
def _check_balances_loaded(balance_date: pendulum.Date, session=None):
    TI = TaskInstance
    count = session.query(func.count()).filter(
        TI.dag_id == lz_alpha_public.DAG_ID,
        TI.task_id == lz_alpha_public.TaskId.BALANCES,
        TI.state.in_(['success', 'skipped']),
        TI.execution_date == pendulum.Pendulum.create(
            balance_date.year, balance_date.month, balance_date.day),
    ).scalar()
    print(f"count={count}")
    if count == 0:
        raise AirflowFailException()
def wait_until_end(self):
    while True:
        dag_run_state = self.get_dag_run_state()
        if dag_run_state == State.SUCCESS:
            print(f"dag_run_state: {dag_run_state}")
            break
        elif dag_run_state == State.RUNNING:
            print(f"dag_run_state: {dag_run_state}")
            time.sleep(30)
        else:
            raise AirflowFailException(
                f"subdag dag_run fail dag_id:{self.dag_name}; run_id:{self.run_id};"
            )
    return dag_run_state
def execute(self, context):
    postgres_hook = PostgresHook(self.redshift_conn_id)
    query_template = """
        select count(*)
        from {table}
        where {filters};
    """
    for (table, cols) in self.checks:
        filters = " or ".join([f"{col} is null" for col in cols])
        table_name = f'stage."{table}_{context["ds"]}"'
        query = query_template.format(table=table_name, filters=filters)
        logging.info(query)
        results = postgres_hook.get_records(query)
        if results is None:
            raise AirflowFailException(
                f"Quality check did not return any results: {query.strip()}"
            )
        elif results[0][0] != 0:
            raise AirflowFailException(
                f"{results[0][0]} rows with disallowed NULLs in {table}"
            )
        else:
            logging.info(f"NULL check passed for {table}")
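# Hypothetical instantiation of the NULL-check operator above, assuming the
# execute() belongs to a custom operator class, here called HasNoNullsOperator
# for illustration. The connection id and the (table, columns) pairs in
# `checks` are also assumptions; they only show the expected input shape.
null_check = HasNoNullsOperator(          # assumed class name
    task_id="check_nulls",
    redshift_conn_id="redshift",          # assumed connection id
    checks=[
        ("users", ["user_id", "email"]),  # illustrative table/column pairs
        ("orders", ["order_id"]),
    ],
)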
def transform_cod_shp(input_filename, adm_level, layer_name=None) -> gpd.GeoDataFrame:
    logger.info(
        f"Transform Cod Shp Arguments:\n"
        f"Input Filename = {input_filename}\n"
        f"Admin Level = {adm_level}\n"
        f"Layer Name = {layer_name}"
    )
    if layer_name is None:
        with zipfile.ZipFile(input_filename) as z:
            for filename in z.namelist():
                if adm_level in filename.lower() and filename.lower().endswith(".shp"):
                    layer_name = filename
    if layer_name is None:
        raise AirflowFailException(
            f"There is no {adm_level} layer in the input dataset!"
        )
    return gpd.read_file(f"zip://{input_filename}!{layer_name}")
def update_campaign(**kwargs):
    ti = kwargs['ti']
    solution_version_arn = ti.xcom_pull(key="return_value",
                                        task_ids='update_solution')
    solution_arn = ti.xcom_pull(key="solution_arn", task_ids='update_solution')

    list_campaigns_response = personalize.list_campaigns(
        solutionArn=solution_arn, maxResults=100)
    demo_campaign = next((campaign
                          for campaign in list_campaigns_response["campaigns"]
                          if campaign["name"] == CAMPAIGN_NAME), False)
    if not demo_campaign:
        create_campaign_response = personalize.create_campaign(
            name=CAMPAIGN_NAME,
            solutionVersionArn=solution_version_arn,
            minProvisionedTPS=2,
        )
        campaign_arn = create_campaign_response['campaignArn']
        logger.info(json.dumps(create_campaign_response, indent=2))
    else:
        campaign_arn = demo_campaign["campaignArn"]
        personalize.update_campaign(campaignArn=campaign_arn,
                                    solutionVersionArn=solution_version_arn,
                                    minProvisionedTPS=2)

    status = None
    max_time = time.time() + 2 * 60 * 60  # 2 hours
    while time.time() < max_time:
        describe_campaign_response = personalize.describe_campaign(
            campaignArn=campaign_arn)
        status = describe_campaign_response["campaign"]["status"]
        print("Campaign: {}".format(status))

        if status == "ACTIVE" or status == "CREATE FAILED":
            break
        time.sleep(60)

    if status == "ACTIVE":
        return campaign_arn
    if status == "CREATE FAILED":
        raise AirflowFailException("Campaign create/update failed")
def airflow_error(job_state: JobState, name: str, job_id: str):
    """Throw an error on a terminal event if the job errored out.

    :param job_state: A JobState enum class
    :param name: The name of your armada job
    :param job_id: The job id that armada assigns to it
    :return: No return, or an AirflowFailException. AirflowFailException tells
        the Airflow scheduler not to reschedule the task.
    """
    if job_state == JobState.SUCCEEDED:
        return
    if (
        job_state == JobState.FAILED
        or job_state == JobState.CANCELLED
        or job_state == JobState.JOB_ID_NOT_FOUND
    ):
        job_message = job_state.name
        raise AirflowFailException(f"The Armada job {name}:{job_id} {job_message}")
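# Hypothetical call sites for airflow_error: once Armada reports a terminal
# state, a failed/cancelled/missing job becomes a non-retryable Airflow
# failure. The job name and id below are placeholders, not real identifiers.
airflow_error(JobState.SUCCEEDED, "my-armada-job", "job-123")  # returns silently
airflow_error(JobState.FAILED, "my-armada-job", "job-123")     # raises AirflowFailException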
def execute(self, context):
    postgres_hook = PostgresHook(self.redshift_conn_id)
    query_template = """
        select {group_cols}, count(*) as count
        from {table}
        group by {group_cols}
        having count > 1;
    """
    for (table, cols) in self.checks:
        group_cols = ", ".join(cols)
        table_name = f'stage."{table}_{context["ds"]}"'
        query = query_template.format(table=table_name, group_cols=group_cols)
        logging.info(query)
        results = postgres_hook.get_records(query)
        if len(results) == 0:
            logging.info(f"Unique column check passed for {table}")
        else:
            raise AirflowFailException(
                f"Uniqueness failed for {table} for columns {group_cols}: {results}"
            )
def create_dataset_group(**kwargs):
    create_dg_response = personalize.create_dataset_group(name=DATASET_GROUP_NAME)
    dataset_group_arn = create_dg_response["datasetGroupArn"]

    status = None
    max_time = time.time() + 2 * 60 * 60  # 2 hours
    while time.time() < max_time:
        describe_dataset_group_response = personalize.describe_dataset_group(
            datasetGroupArn=dataset_group_arn)
        status = describe_dataset_group_response["datasetGroup"]["status"]
        logger.info(f"DatasetGroup: {status}")

        if status == "ACTIVE" or status == "CREATE FAILED":
            break
        time.sleep(20)

    if status == "ACTIVE":
        kwargs['ti'].xcom_push(key="dataset_group_arn", value=dataset_group_arn)
    if status == "CREATE FAILED":
        raise AirflowFailException(
            f"DatasetGroup {DATASET_GROUP_NAME} create failed")
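# One plausible DAG wiring for the Personalize callables above. The dag_id and
# the task ordering are assumptions inferred from the task_ids referenced in
# the xcom_pull calls; the check_dataset_group and create_dataset_type tasks
# that some callables also pull from are omitted here for brevity.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

with DAG(
    dag_id="personalize_pipeline",    # assumed name
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    create_role = PythonOperator(task_id="create_iam_role",
                                 python_callable=create_iam_role)
    create_group = PythonOperator(task_id="create_dataset_group",
                                  python_callable=create_dataset_group)
    import_data = PythonOperator(task_id="import_dataset",
                                 python_callable=import_dataset)
    train = PythonOperator(task_id="update_solution",
                           python_callable=update_solution)
    deploy = PythonOperator(task_id="update_campaign",
                            python_callable=update_campaign)

    [create_role, create_group] >> import_data >> train >> deploy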
def task_function(ti):
    raise AirflowFailException()
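# Hypothetical wiring that illustrates the point of AirflowFailException: even
# with retries configured, the task fails immediately and is not rescheduled.
# The task_id and retry count are assumptions.
from airflow.operators.python import PythonOperator

fail_fast = PythonOperator(
    task_id="task_function",
    python_callable=task_function,
    retries=3,  # not used: AirflowFailException skips the retry cycle
)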
def __init__(
    self,
    source_table_name: Optional[str] = None,
    target_table_name: Optional[str] = None,
    truncate_target: bool = True,
    copy_data: bool = True,
    drop_source: bool = True,
    task_id: str = "copy_table",
    postgres_conn_id: str = "postgres_default",
    database: Optional[str] = None,
    autocommit: bool = False,
    xcom_task_ids: Optional[str] = None,
    xcom_attr_assigner: Callable[[Any, Any], None] = lambda o, x: None,
    xcom_key: str = XCOM_RETURN_KEY,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Initialize PostgresTableCopyOperator.

    Args:
        source_table_name: Name of the temporary table that needs to be copied.
        target_table_name: Name of the table that needs to be created and data copied to.
        truncate_target: Whether the target table should be truncated before being copied to.
        copy_data: Whether to copy data. If set to False, only the table definition is copied.
        drop_source: Whether to drop the source table after the copy process.
        task_id: Task ID.
        postgres_conn_id: The PostgreSQL connection id.
        database: Name of the database to use (if different from the database of the connection id).
        autocommit: What to set the connection's autocommit setting to.
        xcom_task_ids: The id of the task that is providing the xcom info.
        xcom_attr_assigner: Callable that can be provided to assign new values to object attributes.
        xcom_key: Key used to grab the xcom info, defaults to the airflow default `return_value`.
        *args:
        **kwargs:
    """
    super().__init__(*args, task_id=task_id, **kwargs)
    self.source_table_name = source_table_name
    self.target_table_name = target_table_name
    self.truncate_target = truncate_target
    self.copy_data = copy_data
    self.drop_source = drop_source
    self.postgres_conn_id = postgres_conn_id
    self.database = database
    self.autocommit = autocommit
    self.xcom_task_ids = xcom_task_ids
    self.xcom_attr_assigner = xcom_attr_assigner
    self.xcom_key = xcom_key

    # Some checks for valid values
    assert (source_table_name is None) is (
        target_table_name is None
    ), "Source and target should be both None or both provided."

    # Here we build on the previous assertion
    assert bool(source_table_name) ^ (
        xcom_task_ids is not None
    ), "Either table names or xcom_task_ids should be provided."

    if not copy_data and drop_source:
        raise AirflowFailException(
            "Configuration error: source data will not be copied, "
            "even though source table will be dropped.")
def execute(self, context: Dict):
    super().execute(context)
    # self.log.info(self.message)
    raise AirflowFailException(self.message)
def apply_model_request(**kwargs):
    """Request for calling the apply model service"""
    ti = kwargs["ti"]
    dag_run = kwargs["dag_run"]
    dag_run_conf = dag_run.conf
    src_img_name = dag_run_conf["src_img_name"]
    db_tile_basename_list = ti.xcom_pull(task_ids="data_ingestion", key="tile_list")
    bucket = dag_run_conf["src_bucket_name"]
    model_id = ti.xcom_pull(task_ids="data_ingestion", key="model_id")

    service_url = os.environ["APPLYMODEL_URL"]
    base_url = f"http://{service_url}/applyModel/predictByModelId"

    logging.info(f"Processing tile : {db_tile_basename_list[0]}")
    req_params = urllib.parse.urlencode({
        "src_bucket_name": bucket,
        "src_img_name": src_img_name,
        "db_tile_basename": db_tile_basename_list[0]["tile_name"],
        "model_id": model_id,
        "dst_bucket": "predictions",
    })
    req_url = f"{base_url}?{req_params}"

    try:
        resp = requests.post(req_url, timeout=30 * 60)
        out = resp.json()
        status_code = resp.status_code
    except Exception as e:
        status_code = 400
        out = {}
        out["type"] = str(type(e))
        out["message"] = str(e)

    logging.info(f"out: {out}")
    if status_code == 200:
        logging.info(
            f"Success. Tile : {db_tile_basename_list[0]}, Status code : {status_code}"
        )
        ti.xcom_push(key="img_bucket", value=out["bucket"])
        ti.xcom_push(key="img_path", value=out["path"])
        ti.xcom_push(key="db_tile_basename",
                     value=db_tile_basename_list[0]["tile_name"])
    else:
        error_type = out["type"]
        error_msg = out["message"]
        logging.warning(
            f"Prediction failed. Tile : {db_tile_basename_list[0]}, Status code : {status_code}, Error : {error_type}, \n Message : {error_msg}"
        )
        # update db entry
        dag_run_id = dag_run.run_id
        service_url = os.environ["DATAMANAGEMENT_URL"]
        req_url = f"http://{service_url}/productionRuns/updateStatus"
        req_params = urllib.parse.urlencode({
            "dag_run_id": dag_run_id,
            "status": 'failed'
        })
        req_url = f"{req_url}?{req_params}"
        requests.post(req_url)
        raise AirflowFailException(
            f"Prediction failed. Tile : {db_tile_basename_list[0]}, Status code : {status_code}, Error : {error_type}, \n Message : {error_msg}"
        )
def crash_dag() -> None:
    raise AirflowFailException()