Example #1
def main(date, bucket, prefix, num_clusters, num_donors, kernel_bandwidth,
         num_pdf_points, clients_sample_date_from):
    logger.info("Sampling clients since {}".format(clients_sample_date_from))

    spark = (SparkSession.builder.appName(
        "taar_similarity").enableHiveSupport().getOrCreate())

    if num_donors < 100:
        logger.warning("Fewer than 100 donors were requested.",
                       extra={"donors": num_donors})
        num_donors = 100

    logger.info("Loading the AMO whitelist...")
    whitelist = load_amo_curated_whitelist()

    logger.info("Computing the list of donors...")

    # Compute the donors clusters and the LR curves.
    cluster_ids, donors_df = get_donors(spark, num_clusters, num_donors,
                                        whitelist, clients_sample_date_from)
    donors_df.cache()
    lr_curves = get_lr_curves(spark, donors_df, cluster_ids, kernel_bandwidth,
                              num_pdf_points)

    # Store them.
    donors = format_donors_dictionary(donors_df)
    store_json_to_s3(json.dumps(donors, indent=2), "donors", date, prefix,
                     bucket)
    store_json_to_s3(json.dumps(lr_curves, indent=2), "lr_curves", date,
                     prefix, bucket)
    stop_session_safely(spark)
Example #2
def generate_rollups(
    submission_date,
    output_bucket,
    output_prefix,
    output_version,
    transform_func,
    input_bucket=DEFAULT_INPUT_BUCKET,
    input_prefix=DEFAULT_INPUT_PREFIX,
    save_mode=DEFAULT_SAVE_MODE,
    orderBy=[],
):
    """Load main_summary, apply transform_func, and write to S3"""
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info("Running the {0} ETL job...".format(transform_func.__name__))
    start = datetime.datetime.now()
    spark = SparkSession.builder.appName("search_dashboard_etl").getOrCreate()

    source_path = "s3://{}/{}/submission_date_s3={}".format(
        input_bucket, input_prefix, submission_date
    )
    output_path = "s3://{}/{}/v{}/submission_date_s3={}".format(
        output_bucket, output_prefix, output_version, submission_date
    )

    logger.info("Loading main_summary...")
    main_summary = spark.read.parquet(source_path)

    logger.info("Applying transformation function...")
    search_dashboard_data = transform_func(main_summary)

    if orderBy:
        search_dashboard_data = search_dashboard_data.orderBy(*orderBy)

    logger.info("Saving rollups to: {}".format(output_path))
    search_dashboard_data.write.mode(save_mode).save(output_path)

    stop_session_safely(spark)
    logger.info("... done (took: %s)", str(datetime.datetime.now() - start))
Example #3
def main(local, submission_date_s3, input_bucket, input_prefix, output_bucket,
         output_prefix):
    # print argument information
    for k, v in locals().items():
        print("{}: {}".format(k, v))

    print("Python version: {}".format(sys.version_info))
    spark = SparkSession.builder.getOrCreate()
    print("Spark version: {}".format(spark.version))

    # run a basic count over a sample of `main_summary` from 2 days ago
    if not local:
        ds_nodash = submission_date_s3
        input_path = format_spark_path(input_bucket, input_prefix)
        output_path = format_spark_path(output_bucket, output_prefix)

        print(
            "Reading data for {ds_nodash} from {input_path} and writing to {output_path}"
            .format(ds_nodash=ds_nodash,
                    input_path=input_path,
                    output_path=output_path))

        path = "{}/submission_date_s3={}/sample_id={}".format(
            input_path, ds_nodash, 1)
        subset = spark.read.parquet(path)
        print("Saw {} documents".format(subset.count()))

        summary = subset.select("memory_mb", "cpu_cores",
                                "subsession_length").describe()
        summary.show()

        summary.write.parquet(output_path +
                              "/submission_date_s3={}/".format(ds_nodash),
                              mode="overwrite")

    stop_session_safely(spark)
    print("Done!")
Example #4
def test_stop_session_safely_databricks():
    spark_session = Mock(conf={"spark.home": "/databricks/spark"})

    utils.stop_session_safely(spark_session)

    spark_session.stop.assert_not_called()
Example #5
def test_stop_session_safely_emr():
    spark_session = Mock(conf={"spark.home": "/var/lib/spark/"})

    utils.stop_session_safely(spark_session)

    spark_session.stop.assert_called_once()
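Taken together, the two tests pin down the contract of stop_session_safely: leave the session running when spark.home points at a Databricks installation, and stop it otherwise (e.g. on EMR). A minimal sketch consistent with that behavior; the real helper may detect Databricks differently:

def stop_session_safely(spark_session):
    # Sketch only: Databricks manages the session lifecycle itself, so only
    # call stop() when spark.home does not point at a Databricks install.
    spark_home = spark_session.conf.get("spark.home", "")
    if not spark_home.startswith("/databricks"):
        spark_session.stop()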