示例#1
0
def PodOperator(*args, **kwargs):
    """Build a pod operator suited to the current environment.

    Returns a ``GKEPodOperator`` when running in development or when the
    caller passes ``is_gke=True``; otherwise returns a plain
    ``KubernetesPodOperator``.

    Recognised keyword arguments (all others are forwarded untouched):
        namespace: namespace for the pod; defaults to ``"default"``.
        is_gke: force the GKE operator even outside development.
        secrets: iterable of dicts, each expanded into ``Secret(**d)``.
        pod_location: GKE branch only; falls back to ``$POD_LOCATION``.
        cluster_name: GKE branch only; falls back to ``$POD_CLUSTER_NAME``.

    Raises:
        KeyError: in the GKE branch, if ``pod_location`` / ``cluster_name``
            is not supplied and the matching environment variable is unset.
    """
    # TODO: tune this, and add resource limits
    namespace = kwargs.pop("namespace", "default")

    is_gke = kwargs.pop("is_gke", False)  # we want to always pop()

    if "secrets" in kwargs:
        # Build a real list: map() is a one-shot iterator in Python 3, so it
        # would be empty after its first traversal.
        kwargs["secrets"] = [Secret(**d) for d in kwargs["secrets"]]

    development = is_development()

    if not (development or is_gke):
        # NOTE(review): pod_location / cluster_name are not popped on this
        # path, so they are forwarded as-is if the caller supplied them.
        return KubernetesPodOperator(*args, namespace=namespace, **kwargs)

    # Only consult the environment when the caller did not provide a value, so
    # an unset variable cannot raise KeyError for an explicitly passed argument.
    if "pod_location" in kwargs:
        location = kwargs.pop("pod_location")
    else:
        location = os.environ["POD_LOCATION"]

    if "cluster_name" in kwargs:
        cluster_name = kwargs.pop("cluster_name")
    else:
        cluster_name = os.environ["POD_CLUSTER_NAME"]

    return GKEPodOperator(
        *args,
        in_cluster=False,
        project_id="cal-itp-data-infra",  # there currently isn't a staging cluster
        location=location,
        cluster_name=cluster_name,
        namespace=namespace,
        image_pull_policy="Always" if development else "IfNotPresent",
        **kwargs,
    )
    def execute(self, context):
        """Export a 24-hour window of Amplitude events and store it in GCS as JSONL.

        The window starts at hour 0 of the context's ``execution_date`` and
        ends 23 hours later. The events are written to
        ``{app_name}/{start}-{end}.jsonl`` in the dev or prod raw-ingest
        bucket, depending on ``is_development()``.
        """
        # Anchor the window on the DAG's logical date with the hour forced to 0,
        # regardless of the 'schedule_interval'.
        window_start = context.get("execution_date").set(hour=0)

        # The export API's 'end' parameter is inclusive, so +23 hours covers 24 hours:
        # https://developers.amplitude.com/docs/export-api#export-api-parameters
        window_end = window_start + timedelta(hours=23)

        start = window_start.strftime(DATE_FORMAT)
        end = window_end.strftime(DATE_FORMAT)

        events = amplitude_to_df(
            start,
            end,
            api_key_env=self.api_key_env,
            secret_key_env=self.secret_key_env,
            rename_fields=self.rename_fields,
        )
        payload = events.to_json(orient="records", lines=True, date_format="iso")

        destination = f"{self.app_name}/{start}-{end}.jsonl"

        if is_development():
            target_bucket = "ingest_amplitude_raw_dev"
        else:
            target_bucket = "ingest_amplitude_raw_prod"

        # if a file already exists at `destination`, GCS will overwrite the existing file
        calitp.save_to_gcfs(
            payload.encode(),
            destination,
            bucket=target_bucket,
            use_pipe=True,
        )
示例#3
0
def email_failures(task_instance, ds, **kwargs):
    """Email a summary of the feeds the ``download_data`` task failed to fetch.

    Does nothing (beyond printing a notice) in development mode. Otherwise
    pulls the ``download_data`` XCom value, renders its ``"errors"`` entry as
    an HTML table when non-empty, and sends the report by email.

    Args:
        task_instance: used to pull the XCom value from ``download_data``.
        ds: date string interpolated into the failure report body.
    """
    if is_development():
        print("Skipping since in development mode!")
        return

    download_status = task_instance.xcom_pull(task_ids="download_data")
    failed_agencies = download_status["errors"]

    if len(failed_agencies) == 0:
        html_content = "All feeds were downloaded successfully!"
    else:
        failures_table = pd.DataFrame(failed_agencies).to_html(border=False)
        html_content = (
            f"The following agency GTFS feeds could not be extracted on {ds}:\n"
            "\n"
            f"{failures_table}\n"
        )

    recipients = [
        "*****@*****.**",
        "*****@*****.**",
        "*****@*****.**",
        "*****@*****.**",
        "*****@*****.**",
        "*****@*****.**",
        "*****@*****.**",
    ]
    # The subject carries the wall-clock date at send time, not `ds`.
    today = datetime.datetime.now().strftime("%Y-%m-%d")

    send_email(
        to=recipients,
        html_content=html_content,
        subject=f"Operator GTFS Errors for {today}",
    )
def gen_list(execution_date, **kwargs):
    """Task callable: build the feed list, archive it, and return it for XCom.

    Saves both the raw and the API-key-filled feed tables as CSV under
    ``schedule/{execution_date}/metadata`` and returns the filled table,
    minus the GTFS-RT URL columns, as a list of record dicts.
    """
    # The feed URLs come from agencies.yml; both the raw version and the
    # version filled with API keys are fetched so that each can be saved.
    if is_development():
        filled_agencies_file = "data/agencies.filled.yml"
    else:
        filled_agencies_file = "data/agencies.yml"

    feeds_raw = make_gtfs_list(pipe_file_name("data/agencies_raw.yml"))
    feeds = make_gtfs_list(pipe_file_name(filled_agencies_file))

    path_metadata = f"schedule/{execution_date}/metadata"

    for table, csv_name in ((feeds_raw, "feeds_raw.csv"), (feeds, "feeds.csv")):
        save_to_gcfs(
            table.to_csv(index=False).encode(),
            f"{path_metadata}/{csv_name}",
            use_pipe=True,
        )

    # Note that right now we use airflow's xcom functionality in this dag.
    # Because xcom can only store a small amount of data, some columns have to
    # be dropped. This is the only dag that uses xcom, and we should remove it!
    realtime_url_columns = [
        "gtfs_rt_vehicle_positions_url",
        "gtfs_rt_service_alerts_url",
        "gtfs_rt_trip_updates_url",
    ]
    trimmed = feeds.drop(columns=realtime_url_columns)

    return trimmed.to_dict("records")
示例#5
0
    def __init__(
        self,
        *args,
        bucket=None,
        prefix_bucket=False,
        destination_project_dataset_table=None,  # note that the project is optional here
        skip_leading_rows=1,
        schema_fields=None,
        hive_options=None,
        source_objects=None,
        source_format="CSV",
        geojson=False,
        use_bq_client=False,
        field_delimiter=",",
        post_hook=None,
        **kwargs,
    ):
        """Store the operator's configuration and initialise the parent class.

        Args:
            bucket: bucket URI; rewritten to its ``test-`` variant when
                ``prefix_bucket`` is truthy and ``is_development()`` holds.
            prefix_bucket: opt in to the development bucket rewrite.
            destination_project_dataset_table: passed through
                ``format_table_name`` before being stored.
            skip_leading_rows, schema_fields, hive_options, source_format,
            geojson, use_bq_client, field_delimiter, post_hook: stored
                unchanged on the instance.
            source_objects: stored on the instance; a fresh empty list is
                used when omitted, so instances never share one default list.
            **kwargs: forwarded to the parent ``__init__``.

        Note:
            Positional ``*args`` are accepted but not forwarded to the parent.
        """
        self.bucket = bucket
        # This only exists because the prefix_bucket() template isn't working in the yml file for some reason
        if self.bucket and prefix_bucket and is_development():
            self.bucket = re.sub(r"gs://([\w-]+)", r"gs://test-\1",
                                 self.bucket)

        self.destination_project_dataset_table = format_table_name(
            destination_project_dataset_table)
        self.skip_leading_rows = skip_leading_rows
        self.schema_fields = schema_fields
        # A `[]` default in the signature would be one list object shared (and
        # mutable) across every instance; create a new one per instance instead.
        self.source_objects = [] if source_objects is None else source_objects
        self.source_format = source_format
        self.geojson = geojson
        self.hive_options = hive_options
        self.use_bq_client = use_bq_client
        self.field_delimiter = field_delimiter
        self.post_hook = post_hook

        super().__init__(**kwargs)
示例#6
0
def prefix_bucket(bucket):
    """Return ``bucket`` as a ``gs://`` URI, using the ``test-`` bucket in development.

    Args:
        bucket: bucket name, with or without a leading ``gs://`` scheme.

    Returns:
        ``gs://test-<bucket>`` when ``is_development()`` is true, otherwise
        ``gs://<bucket>``.
    """
    scheme = "gs://"
    # Strip only a leading scheme; str.replace would also remove any later
    # "gs://" occurrence inside the value.
    # TODO: use bucket.removeprefix("gs://") once we're in python 3.9+
    if bucket.startswith(scheme):
        bucket = bucket[len(scheme):]
    return f"gs://test-{bucket}" if is_development() else f"gs://{bucket}"
示例#7
0
def is_development_macro():
    """Expose calitp-py's ``is_development`` check so templates can call it as a macro."""
    in_development = is_development()
    return in_development
示例#8
0
    FROM
        `airtable.california_transit_{table}` T2
)

SELECT *
FROM unnested_t1
LEFT JOIN t2 USING({id2})
"""

# ACTUALLY DEFINE MACROS =============================================================

# template must be added here to be accessed in dags.py
# key is alias that will be used to reference the template in DAG tasks
# value is name of function template as defined above


def prefix_bucket(bucket):
    """Return ``bucket`` as a ``gs://`` URI, using the ``test-`` bucket in development.

    Args:
        bucket: bucket name, with or without a leading ``gs://`` scheme.

    Returns:
        ``gs://test-<bucket>`` when ``is_development()`` is true, otherwise
        ``gs://<bucket>``.
    """
    scheme = "gs://"
    # Strip only a leading scheme; str.replace would also remove any later
    # "gs://" occurrence inside the value.
    # TODO: use bucket.removeprefix("gs://") once we're in python 3.9+
    if bucket.startswith(scheme):
        bucket = bucket[len(scheme):]
    return f"gs://test-{bucket}" if is_development() else f"gs://{bucket}"


# Macro registry: each key is the alias used inside DAG task templates and each
# value is the callable that the alias resolves to.
data_infra_macros = dict(
    sql_airtable_mapping=airtable_mapping_generate_sql,
    is_development=is_development_macro,
    image_tag=lambda: "latest" if not is_development() else "development",
    env_var=lambda key: os.getenv(key),
    prefix_bucket=prefix_bucket,
)
示例#9
0
def download_all(task_instance, execution_date, **kwargs):
    """Download every eligible GTFS schedule feed, record outcomes, and report failures.

    Steps, in order:
      1. Build ``auth_dict`` (key -> val) from every ``Variable`` row in the session.
      2. Select records from the latest ``AirtableGTFSDataExtract`` whose
         ``data_quality_pipeline`` is truthy and whose ``data`` equals
         ``GTFSFeedType.schedule``.
      3. For each record, download and save the feed, appending a success or
         failure outcome (exceptions are logged and captured, not re-raised).
      4. Save a ``DownloadFeedsResult`` (``results.jsonl``) via ``get_fs()``.
      5. Email a report unless ``is_development()`` is true.
      6. Raise if the success rate is below ``GTFS_FEED_LIST_ERROR_THRESHOLD``.

    ``task_instance`` and ``execution_date`` are accepted but not read in the body.

    Raises:
        AssertionError: if the number of outcomes differs from the number of records.
        RuntimeError: if the success rate is below the configured threshold.
    """
    # Single timestamp shared by every download and by the saved result.
    start = pendulum.now()
    # https://stackoverflow.com/a/61808755
    with create_session() as session:
        auth_dict = {var.key: var.val for var in session.query(Variable)}

    records = [
        record for record in AirtableGTFSDataExtract.get_latest().records if
        record.data_quality_pipeline and record.data == GTFSFeedType.schedule
    ]
    outcomes: List[AirtableGTFSDataRecordProcessingOutcome] = []

    logging.info(f"processing {len(records)} records")

    for i, record in enumerate(records, start=1):
        logging.info(f"attempting to fetch {i}/{len(records)} {record.uri}")

        try:
            # this is a bit hacky but we need this until we split off auth query params from the URI itself
            # Matches `name={{ KEY }}` in the URI; the name/key pair is moved onto
            # the record and the matched text is removed from the URI.
            # NOTE(review): only the matched text is removed, so a leading "?" or
            # "&" presumably stays in the URI — confirm download_feed tolerates it.
            jinja_pattern = r"(?P<param_name>\w+)={{\s*(?P<param_lookup_key>\w+)\s*}}"
            match = re.search(jinja_pattern, record.uri)
            if match:
                record.auth_query_param = {
                    match.group("param_name"): match.group("param_lookup_key")
                }
                record.uri = re.sub(jinja_pattern, "", record.uri)

            extract, content = download_feed(
                record,
                auth_dict=auth_dict,
                ts=start,
            )

            extract.save_content(fs=get_fs(), content=content)

            outcomes.append(
                AirtableGTFSDataRecordProcessingOutcome(
                    success=True,
                    airtable_record=record,
                    extract=extract,
                ))
        except Exception as e:
            # Deliberately broad: one failing feed must not abort the rest of
            # the batch; the exception is stored on the outcome instead.
            logging.error(
                f"exception occurred while attempting to download feed {record.uri}: {str(e)}\n{traceback.format_exc()}"
            )
            outcomes.append(
                AirtableGTFSDataRecordProcessingOutcome(
                    success=False,
                    exception=e,
                    airtable_record=record,
                ))

    # TODO: save the outcomes somewhere

    print(
        f"took {humanize.naturaltime(pendulum.now() - start)} to process {len(records)} records"
    )

    # Every record should have produced exactly one outcome (success or failure).
    # NOTE(review): assert statements are stripped under `python -O`.
    assert len(records) == len(
        outcomes
    ), f"we somehow ended up with {len(outcomes)} outcomes from {len(records)} records"

    result = DownloadFeedsResult(
        ts=start,
        end=pendulum.now(),
        outcomes=outcomes,
        filename="results.jsonl",
    )

    result.save(get_fs())

    print(f"successfully fetched {len(result.successes)} of {len(records)}")

    if result.failures:
        # Falls back to the exception's type when its string form is empty.
        print(
            "Failures:\n",
            "\n".join(
                str(f.exception) or str(type(f.exception))
                for f in result.failures),
        )
        # use pandas begrudgingly for email HTML since the old task used it
        html_report = pd.DataFrame(
            f.dict() for f in result.failures).to_html(border=False)

        html_content = f"""\
    NOTE: These failures come from the v2 of the GTFS Schedule downloader.

    The following agency GTFS feeds could not be extracted on {start.to_iso8601_string()}:

    {html_report}
    """
    else:
        html_content = "All feeds were downloaded successfully!"

    if is_development():
        print(
            f"Skipping since in development mode! Would have emailed {len(result.failures)} failures."
        )
    else:
        # An email is sent on every non-development run, including all-success runs.
        send_email(
            to=[
                "*****@*****.**",
                "*****@*****.**",
            ],
            html_content=html_content,
            subject=
            (f"Operator GTFS Errors for {datetime.datetime.now().strftime('%Y-%m-%d')}"
             ),
        )

    # NOTE(review): raises ZeroDivisionError when no records matched the filter
    # above — confirm whether an empty record list should fail the task this way.
    success_rate = len(result.successes) / len(records)
    if success_rate < GTFS_FEED_LIST_ERROR_THRESHOLD:
        raise RuntimeError(
            f"Success rate: {success_rate:.3f} was below error threshold: {GTFS_FEED_LIST_ERROR_THRESHOLD}"
        )