Example #1
    def execute(self, context):
        cluster_id = self.config.get("CLUSTER", "CLUSTER_ID")

        aws_hook = AwsHook(self.aws_credentials_id)
        aws_credentials = aws_hook.get_credentials()

        logging.info(f"Creating redshift cluster {cluster_id}...")
        redshift = boto3.client(
            "redshift",
            region_name="us-west-2",
            aws_access_key_id=aws_credentials.access_key,
            aws_secret_access_key=aws_credentials.secret_key,
        )
        try:
            response = redshift.create_cluster(
                ClusterIdentifier=cluster_id,
                ClusterType=self.config.get("CLUSTER", "CLUSTER_TYPE"),
                NodeType=self.config.get("CLUSTER", "NODE_TYPE"),
                NumberOfNodes=int(self.config.get("CLUSTER",
                                                  "NUMBER_OF_NODES")),
                DBName=self.config.get("CLUSTER", "DB_NAME"),
                Port=int(self.config.get("CLUSTER", "DB_PORT")),
                MasterUsername=self.config.get("CLUSTER", "DB_USER"),
                MasterUserPassword=self.config.get("CLUSTER", "DB_PASSWORD"),
                IamRoles=[self.config.get("IAM_ROLE", "ARN")],
            )
            logging.info(f"{cluster_id} created")
            logging.info(response)
        except redshift.exceptions.ClusterAlreadyExistsFault:
            logging.info(f"{cluster_id} already exists")
        except Exception as e:
            raise AirflowFailException(e)
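A possible follow-up (not part of the original operator): create_cluster returns while the cluster is still provisioning, so a downstream step that needs a reachable endpoint could block on a boto3 waiter. A minimal sketch, reusing the redshift client and cluster_id from above:

# Hedged sketch: poll every 60 seconds, up to 30 attempts, until the cluster
# reports "available". Raises botocore.exceptions.WaiterError on timeout.
waiter = redshift.get_waiter("cluster_available")
waiter.wait(
    ClusterIdentifier=cluster_id,
    WaiterConfig={"Delay": 60, "MaxAttempts": 30},
)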
Example #2
    def extract_nyt_best_sellers(
        date: str = 'current',
        list_name: list = ['combined-print-and-e-book-fiction',
                           'combined-print-and-e-book-nonfiction']) -> list:
        '''
        Extracts New York Times best seller list information for
        combined-print-and-e-book-fiction and combined-print-and-e-book-nonfiction
        and returns a list of results.
        '''
        requestHeaders = {
            "Accept": "application/json"
        }
        nyt_best_sellers = []
        # The date argument is overridden by the Airflow Variable that tracks
        # the next list date to fetch.
        date = Variable.get('next_update_date')
        api_key = Variable.get('NYT_API_KEY')
        for list_ in list_name:
            url = f'https://api.nytimes.com/svc/books/v3/lists/{date}/{list_}.json?api-key={api_key}'
            data = requests.get(url, headers=requestHeaders)
            if not data:  # requests.Response is falsy for 4xx/5xx status codes
                raise AirflowFailException('Date or Data Point not available')
            else:
                nyt_best_sellers.append(data.json()['results'])
        # If the requests succeed, then advance the next_update_date variable by a week.
        next_update_date = datetime.strptime(date, '%Y-%m-%d')
        next_update_date += timedelta(days=7)
        next_update_date = next_update_date.strftime('%Y-%m-%d')
        Variable.set('next_update_date', next_update_date)

        return nyt_best_sellers
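The snippet above is written to run inside a DAG (it reads Airflow Variables), so it would typically be wrapped in a task. A minimal, assumed wiring; the task id and the dag object are illustrative, not from the source:

from airflow.operators.python import PythonOperator

# Hypothetical task definition; `dag` is assumed to be an existing DAG object.
extract_best_sellers = PythonOperator(
    task_id="extract_nyt_best_sellers",
    python_callable=extract_nyt_best_sellers,
    dag=dag,
)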
Example #3
    def execute(self, context: "Context") -> List[str]:
        hook = GoogleDriveHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )

        remote_file_ids = []

        for local_path in self.local_paths:
            self.log.info("Uploading file to Google Drive: %s", local_path)

            try:
                remote_file_id = hook.upload_file(
                    local_location=str(local_path),
                    remote_location=str(Path(self.drive_folder) / Path(local_path).name),
                    chunk_size=self.chunk_size,
                    resumable=self.resumable,
                )

                remote_file_ids.append(remote_file_id)

                if self.delete:
                    os.remove(local_path)
                    self.log.info("Deleted local file: %s", local_path)
            except FileNotFoundError:
                self.log.warning("File can't be found: %s", local_path)
            except OSError:
                self.log.warning("An OSError occurred for file: %s", local_path)

        if not self.ignore_if_missing and len(remote_file_ids) < len(self.local_paths):
            raise AirflowFailException("Some files couldn't be uploaded")
        return remote_file_ids
Example #4
def create_iam_role(**kwargs):
    assume_role_policy_document = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": "personalize.amazonaws.com"},
                "Action": "sts:AssumeRole",
            }
        ],
    }
    try:
        create_role_response = iam.create_role(
            RoleName=PERSONALIZE_ROLE_NAME,
            AssumeRolePolicyDocument=json.dumps(assume_role_policy_document))

        iam.attach_role_policy(
            RoleName=PERSONALIZE_ROLE_NAME,
            PolicyArn="arn:aws:iam::aws:policy/AmazonS3ReadOnlyAccess")

        role_arn = create_role_response["Role"]["Arn"]

        # sometimes need to wait a bit for the role to be created
        time.sleep(30)
        return role_arn
    except ClientError as e:
        if e.response['Error']['Code'] == 'EntityAlreadyExists':
            role_arn = iam.get_role(
                RoleName=PERSONALIZE_ROLE_NAME)['Role']['Arn']
            time.sleep(30)
            return role_arn
        else:
            raise AirflowFailException(f"PersonalizeS3Role create failed")
Example #5
 def trigger_loop(self, submission, task_list, mdata):
     # loop_num = None
     c = Client(None, None)
     submission_dict = submission.serialize()
     task_dict_list = [task.serialize() for task in task_list]
     submission_hash = submission.submission_hash
     self.run_id = f"dag_run_{submission_hash}"
     try:
         c.trigger_dag(dag_id=self.dag_name,
                       run_id=self.run_id,
                       conf={
                           'submission_dict': submission_dict,
                           'task_dict_list': task_dict_list,
                           'mdata': mdata
                       })
     except DagRunAlreadyExists:
         dag_run_state = self.get_dag_run_state()
         if dag_run_state == State.FAILED:
             raise AirflowFailException(
                 f"subdag dag_run fail dag_id:{self.dag_name}; run_id:{self.run_id};"
             )
         else:
             print(f"continue from old dag_run {self.run_id}")
     loop_return = self.wait_until_end()
     return loop_return
Example #6
def run_soda_scan(warehouse_yml_file, scan_yml_file):
    scan_builder = ScanBuilder()
    scan_builder.warehouse_yml_file = warehouse_yml_file
    scan_builder.scan_yml_file = scan_yml_file
    scan = scan_builder.build()
    scan_result = scan.execute()
    if scan_result.has_test_failures():
        failures = scan_result.get_test_failures_count()
        raise AirflowFailException(f"Soda Scan found {failures} errors in your data!")
Example #7
def import_dataset(**kwargs):
    ti = kwargs['ti']
    interactions_dataset_arn = ti.xcom_pull(key="return_value",
                                            task_ids='create_dataset_type')
    role_arn = ti.xcom_pull(key="return_value", task_ids='create_iam_role')

    if not interactions_dataset_arn:
        dataset_group_arn = ti.xcom_pull(key="dataset_group_arn",
                                         task_ids='check_dataset_group')

        list_datasets_response = personalize.list_datasets(
            datasetGroupArn=dataset_group_arn, maxResults=100)
        interaction_dataset = next(
            (dataset for dataset in list_datasets_response["datasets"]
             if dataset["name"] == INTERACTION_DATASET_NAME), False)

        interactions_dataset_arn = interaction_dataset["datasetArn"]

    if not role_arn:
        role_arn = iam.get_role(RoleName=PERSONALIZE_ROLE_NAME)['Role']['Arn']
        time.sleep(30)

    create_dataset_import_job_response = personalize.create_dataset_import_job(
        jobName="DEMO-dataset-import-job-" + suffix,
        datasetArn=interactions_dataset_arn,
        dataSource={"dataLocation": OUTPUT_PATH},
        roleArn=role_arn)

    dataset_import_job_arn = create_dataset_import_job_response[
        'datasetImportJobArn']
    logger.info(json.dumps(create_dataset_import_job_response, indent=2))

    status = None
    max_time = time.time() + 2 * 60 * 60  # 2 hours

    while time.time() < max_time:
        describe_dataset_import_job_response = personalize.describe_dataset_import_job(
            datasetImportJobArn=dataset_import_job_arn)

        dataset_import_job = describe_dataset_import_job_response[
            "datasetImportJob"]
        if "latestDatasetImportJobRun" not in dataset_import_job:
            status = dataset_import_job["status"]
            logger.info("DatasetImportJob: {}".format(status))
        else:
            status = dataset_import_job["latestDatasetImportJobRun"]["status"]
            logger.info("LatestDatasetImportJobRun: {}".format(status))

        if status == "ACTIVE" or status == "CREATE FAILED":
            break

        time.sleep(60)

    if status == "ACTIVE":
        return dataset_import_job_arn
    if status == "CREATE FAILED":
        raise AirflowFailException(f"Dataset import job create failed")
Example #8
def update_solution(**kwargs):
    recipe_arn = "arn:aws:personalize:::recipe/aws-popularity-count"
    ti = kwargs['ti']

    dataset_group_arn = ti.xcom_pull(key="dataset_group_arn",
                                     task_ids='check_dataset_group')
    if not dataset_group_arn:
        dataset_group_arn = ti.xcom_pull(key="dataset_group_arn",
                                         task_ids='create_dataset_group')

    list_solutions_response = personalize.list_solutions(
        datasetGroupArn=dataset_group_arn, maxResults=100)

    demo_solution = next((solution
                          for solution in list_solutions_response["solutions"]
                          if solution["name"] == SOLUTION_NAME), False)

    if not demo_solution:
        create_solution_response = personalize.create_solution(
            name=SOLUTION_NAME,
            datasetGroupArn=dataset_group_arn,
            recipeArn=recipe_arn)

        solution_arn = create_solution_response['solutionArn']
        logger.info(json.dumps(create_solution_response, indent=2))
    else:
        solution_arn = demo_solution["solutionArn"]

    kwargs['ti'].xcom_push(key="solution_arn", value=solution_arn)
    create_solution_version_response = personalize.create_solution_version(
        solutionArn=solution_arn, trainingMode='FULL')

    solution_version_arn = create_solution_version_response[
        'solutionVersionArn']

    status = None
    max_time = time.time() + 2 * 60 * 60  # 2 hours
    while time.time() < max_time:
        describe_solution_version_response = personalize.describe_solution_version(
            solutionVersionArn=solution_version_arn)
        status = describe_solution_version_response["solutionVersion"][
            "status"]
        logger.info(f"SolutionVersion: {status}")

        if status == "ACTIVE" or status == "CREATE FAILED":
            break

        time.sleep(60)

    if status == "ACTIVE":
        return solution_version_arn
    if status == "CREATE FAILED":
        raise AirflowFailException(f"Solution version create failed")
Example #9
def _check_balances_loaded(balance_date: pendulum.Date, session=None):
    TI = TaskInstance
    count = session.query(func.count()).filter(
        TI.dag_id == lz_alpha_public.DAG_ID,
        TI.task_id == lz_alpha_public.TaskId.BALANCES,
        TI.state.in_(['success', 'skipped']),
        TI.execution_date == pendulum.Pendulum.create(balance_date.year,
                                                      balance_date.month,
                                                      balance_date.day),
    ).scalar()
    print(f"count={count}")
    if count == 0:
        raise AirflowFailException(f"Balances for {balance_date} have not been loaded")
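The session=None default suggests the session is injected rather than passed by hand, most likely via Airflow's provide_session decorator; this is an assumption about the surrounding code, not shown in the source. A sketch:

from airflow.utils.session import provide_session  # airflow.utils.db in older Airflow releases


# Hypothetical wrapper that lets Airflow supply the SQLAlchemy session.
@provide_session
def check_balances_loaded(balance_date: pendulum.Date, session=None):
    _check_balances_loaded(balance_date, session=session)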
Example #10
 def wait_until_end(self):
     while True:
         dag_run_state = self.get_dag_run_state()
         if dag_run_state == State.SUCCESS:
             print(f"dag_run_state: {dag_run_state}")
             break
         elif dag_run_state == State.RUNNING:
             print(f"dag_run_state: {dag_run_state}")
             time.sleep(30)
         else:
             raise AirflowFailException(
                 f"subdag dag_run fail dag_id:{self.dag_name}; run_id:{self.run_id};"
             )
     return dag_run_state
Example #11
    def execute(self, context):
        postgres_hook = PostgresHook(self.redshift_conn_id)
        query_template = """
            select count(*)
            from {table}
            where {filters};
        """

        for (table, cols) in self.checks:
            filters = " or ".join([f"{col} is null" for col in cols])
            table_name = f'stage."{table}_{context["ds"]}"'
            query = query_template.format(table=table_name, filters=filters)
            logging.info(query)
            results = postgres_hook.get_records(query)
            if results is None:
                raise AirflowFailException(
                    f"Quality check did not return any results: {query.strip}"
                )
            elif results[0][0] != 0:
                raise AirflowFailException(
                    f"{results[0][0]} rows with disallowed NULLs in {table}"
                )
            else:
                logging.info(f"NULL check passed for {table}")
Example #12
def transform_cod_shp(input_filename, adm_level, layer_name=None) -> gpd.GeoDataFrame:
    logger.info(
        f"Transform Cod Shp Arguments:\n"
        f"Input Filename = {input_filename}\n"
        f"Admin Level = {adm_level}\n"
        f"Layer Name = {layer_name}"
    )
    if layer_name is None:
        with zipfile.ZipFile(input_filename) as z:
            for filename in z.namelist():
                if adm_level in filename.lower() and filename.lower().endswith(".shp"):
                    layer_name = filename
    if layer_name is None:
        raise AirflowFailException(
            f"There is no {adm_level} layer in the input dataset!"
        )
    return gpd.read_file(f"zip://{input_filename}!{layer_name}")
Example #13
def update_campaign(**kwargs):
    ti = kwargs['ti']
    solution_version_arn = ti.xcom_pull(key="return_value",
                                        task_ids='update_solution')
    solution_arn = ti.xcom_pull(key="solution_arn", task_ids='update_solution')

    list_campaigns_response = personalize.list_campaigns(
        solutionArn=solution_arn, maxResults=100)

    demo_campaign = next((campaign
                          for campaign in list_campaigns_response["campaigns"]
                          if campaign["name"] == CAMPAIGN_NAME), False)
    if not demo_campaign:
        create_campaign_response = personalize.create_campaign(
            name=CAMPAIGN_NAME,
            solutionVersionArn=solution_version_arn,
            minProvisionedTPS=2,
        )

        campaign_arn = create_campaign_response['campaignArn']
        logger.info(json.dumps(create_campaign_response, indent=2))
    else:
        campaign_arn = demo_campaign["campaignArn"]
        personalize.update_campaign(campaignArn=campaign_arn,
                                    solutionVersionArn=solution_version_arn,
                                    minProvisionedTPS=2)

    status = None
    max_time = time.time() + 2 * 60 * 60  # 2 hours
    while time.time() < max_time:
        describe_campaign_response = personalize.describe_campaign(
            campaignArn=campaign_arn)
        status = describe_campaign_response["campaign"]["status"]
        print("Campaign: {}".format(status))

        if status == "ACTIVE" or status == "CREATE FAILED":
            break

        time.sleep(60)

    if status == "ACTIVE":
        return campaign_arn
    if status == "CREATE FAILED":
        raise AirflowFailException(f"Campaign create/update failed")
Example #14
def airflow_error(job_state: JobState, name: str, job_id: str):
    """Throw an error on a terminal event if job errored out

    :param job_state: A JobState enum class
    :param name: The name of your armada job
    :param job_id: The job id that armada assigns to it
    :return: No Return or an AirflowFailException.

    AirflowFailException tells Airflow Schedule to not reschedule the task

    """
    if job_state == JobState.SUCCEEDED:
        return
    if (
        job_state == JobState.FAILED
        or job_state == JobState.CANCELLED
        or job_state == JobState.JOB_ID_NOT_FOUND
    ):
        job_message = job_state.name
        raise AirflowFailException(f"The Armada job {name}:{job_id} {job_message}")
Example #15
    def execute(self, context):
        postgres_hook = PostgresHook(self.redshift_conn_id)
        query_template = """
            select {group_cols}, count(*) as count
            from {table}
            group by {group_cols}
            having count > 1;
        """

        for (table, cols) in self.checks:
            group_cols = ", ".join(cols)
            table_name = f'stage."{table}_{context["ds"]}"'
            query = query_template.format(table=table_name,
                                          group_cols=group_cols)
            logging.info(query)
            results = postgres_hook.get_records(query)
            if len(results) == 0:
                logging.info(f"Unique column check passed for {table}")
            else:
                raise AirflowFailException(
                    f"Uniqueness failed for {table} for columns {group_cols}: {results}"
                )
Example #16
def create_dataset_group(**kwargs):
    create_dg_response = personalize.create_dataset_group(
        name=DATASET_GROUP_NAME)
    dataset_group_arn = create_dg_response["datasetGroupArn"]

    status = None
    max_time = time.time() + 2 * 60 * 60  # 2 hours
    while time.time() < max_time:
        describe_dataset_group_response = personalize.describe_dataset_group(
            datasetGroupArn=dataset_group_arn)
        status = describe_dataset_group_response["datasetGroup"]["status"]
        logger.info(f"DatasetGroup: {status}")

        if status == "ACTIVE" or status == "CREATE FAILED":
            break

        time.sleep(20)
    if status == "ACTIVE":
        kwargs['ti'].xcom_push(key="dataset_group_arn",
                               value=dataset_group_arn)
    if status == "CREATE FAILED":
        raise AirflowFailException(
            f"DatasetGroup {DATASET_GROUP_NAME} create failed")
Example #17
 def task_function(ti):
     raise AirflowFailException()
Example #18
    def __init__(
        self,
        source_table_name: Optional[str] = None,
        target_table_name: Optional[str] = None,
        truncate_target: bool = True,
        copy_data: bool = True,
        drop_source: bool = True,
        task_id: str = "copy_table",
        postgres_conn_id: str = "postgres_default",
        database: Optional[str] = None,
        autocommit: bool = False,
        xcom_task_ids: Optional[str] = None,
        xcom_attr_assigner: Callable[[Any, Any], None] = lambda o, x: None,
        xcom_key: str = XCOM_RETURN_KEY,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize PostgresTableCopyOperator.

        Args:
            source_table_name: Name of the temporary table that needs to be copied.
            target_table_name: Name of the table that needs to be created and data copied to.
            truncate_target: Whether the target table should be truncated before being copied to.
            copy_data: Whether to copy data. If set to False, only the table definition is copied.
            drop_source: Whether to drop the source table after the copy process.
            task_id: Task ID
            postgres_conn_id: The PostgreSQL connection id.
            database: Name of the database to use (if different from the database of the connection id).
            autocommit: What to set the connection's autocommit setting to.
            xcom_task_ids: The id of the task that is providing the xcom info.
            xcom_attr_assigner: Callable that can be provided to assign new values
                to object attributes.
            xcom_key: Key used to grab the xcom info, defaults to the Airflow
                default `return_value`.
            *args:
            **kwargs:
        """
        super().__init__(*args, task_id=task_id, **kwargs)
        self.source_table_name = source_table_name
        self.target_table_name = target_table_name
        self.truncate_target = truncate_target
        self.copy_data = copy_data
        self.drop_source = drop_source
        self.postgres_conn_id = postgres_conn_id
        self.database = database
        self.autocommit = autocommit

        self.xcom_task_ids = xcom_task_ids
        self.xcom_attr_assigner = xcom_attr_assigner
        self.xcom_key = xcom_key

        # Some checks for valid values
        assert (source_table_name is None) is (
            target_table_name is
            None), "Source and target should be both None or both provided."

        # Here we build on the previous assertion
        assert bool(source_table_name) ^ (
            xcom_task_ids is not None
        ), "Either table names or xcom_task_ids should be provided."

        if not copy_data and drop_source:
            raise AirflowFailException(
                "Configuration error: source data will not be copied, "
                "even though source table will be dropped.")
Example #19
 def execute(self, context: Dict):
     super().execute(context)
     # self.log.info(self.message)
     raise AirflowFailException(self.message)
Example #20
def apply_model_request(**kwargs):
    """Request for calling the apply model service"""
    ti = kwargs["ti"]
    dag_run = kwargs["dag_run"]
    dag_run_conf = dag_run.conf
    src_img_name = dag_run_conf["src_img_name"]
    db_tile_basename_list = ti.xcom_pull(task_ids="data_ingestion",
                                         key="tile_list")
    bucket = dag_run_conf["src_bucket_name"]
    model_id = ti.xcom_pull(task_ids="data_ingestion", key="model_id")
    service_url = os.environ["APPLYMODEL_URL"]
    base_url = f"http://{service_url}/applyModel/predictByModelId"
    logging.info(f"Processing tile : {db_tile_basename_list[0]}")
    req_params = urllib.parse.urlencode({
        "src_bucket_name": bucket,
        "src_img_name": src_img_name,
        "db_tile_basename": db_tile_basename_list[0]["tile_name"],
        "model_id": model_id,
        "dst_bucket": "predictions",
    })
    req_url = f"{base_url}?{req_params}"
    try:
        resp = requests.post(req_url, timeout=30 * 60)
        out = resp.json()
        status_code = resp.status_code
    except Exception as e:
        status_code = 400
        out = {}
        out["type"] = str(type(e))
        out["message"] = str(e)

    logging.info(f"out: {out}")
    if status_code == 200:
        logging.info(
            f"Success. Tile : {db_tile_basename_list[0]}, Status code : {status_code}"
        )
        ti.xcom_push(key="img_bucket", value=out["bucket"])
        ti.xcom_push(key="img_path", value=out["path"])
        ti.xcom_push(key="db_tile_basename",
                     value=db_tile_basename_list[0]["tile_name"])
    else:
        error_type = out["type"]
        error_msg = out["message"]
        logging.warn(
            f"Prediction failed. Tile : {db_tile_basename_list[0]}, Status code : {status_code}, Error : {error_type}, \n Message : {error_msg}"
        )
        # update db entry
        dag_run_id = dag_run.run_id
        service_url = os.environ["DATAMANAGEMENT_URL"]
        req_url = f"http://{service_url}/productionRuns/updateStatus"
        req_params = urllib.parse.urlencode({
            "dag_run_id": dag_run_id,
            "status": 'failed'
        })
        req_url = f"{req_url}?{req_params}"
        requests.post(req_url)
        raise AirflowFailException(
            f"Prediction failed. Tile : {db_tile_basename_list[0]}, Status code : {status_code}, Error : {error_type}, \n Message : {error_msg}"
        )
Example #21
def crash_dag() -> None:
    raise AirflowFailException()
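A minimal sketch of how such an always-failing callable might be used, for example as a guard task in a test DAG; the DAG id, start date, and schedule are assumptions, not from the source:

from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

with DAG(
    dag_id="crash_dag_demo",          # assumed name
    start_date=datetime(2023, 1, 1),  # assumed start date
    schedule_interval=None,
    catchup=False,
) as dag:
    PythonOperator(
        task_id="crash_dag",
        python_callable=crash_dag,
    )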