def main(date, bucket, prefix, num_clusters, num_donors, kernel_bandwidth,
         num_pdf_points, clients_sample_date_from):
    """Compute the TAAR similarity donors and LR curves and store them to S3."""
    logger.info("Sampling clients since {}".format(clients_sample_date_from))

    spark = (
        SparkSession.builder.appName("taar_similarity")
        .enableHiveSupport()
        .getOrCreate()
    )

    if num_donors < 100:
        logger.warning(
            "Less than 100 donors were requested.", extra={"donors": num_donors}
        )
        num_donors = 100

    logger.info("Loading the AMO whitelist...")
    whitelist = load_amo_curated_whitelist()

    logger.info("Computing the list of donors...")

    # Compute the donors clusters and the LR curves.
    cluster_ids, donors_df = get_donors(
        spark, num_clusters, num_donors, whitelist, clients_sample_date_from
    )
    donors_df.cache()
    lr_curves = get_lr_curves(
        spark, donors_df, cluster_ids, kernel_bandwidth, num_pdf_points
    )

    # Store them.
    donors = format_donors_dictionary(donors_df)
    store_json_to_s3(json.dumps(donors, indent=2), "donors", date, prefix, bucket)
    store_json_to_s3(
        json.dumps(lr_curves, indent=2), "lr_curves", date, prefix, bucket
    )

    stop_session_safely(spark)
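# `store_json_to_s3` is used above but defined elsewhere; the sketch below is a
# minimal, hypothetical version assuming a boto3 upload and a
# "{prefix}/{name}/{date}/{name}.json" key layout (both are assumptions, not the
# repository's actual implementation).
import boto3


def store_json_to_s3_sketch(json_data, name, date, prefix, bucket):
    # Hypothetical key layout; the real helper may partition keys differently.
    key = "{}/{}/{}/{}.json".format(prefix, name, date, name)
    boto3.client("s3").put_object(
        Bucket=bucket, Key=key, Body=json_data.encode("utf-8")
    )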
def generate_rollups(
    submission_date,
    output_bucket,
    output_prefix,
    output_version,
    transform_func,
    input_bucket=DEFAULT_INPUT_BUCKET,
    input_prefix=DEFAULT_INPUT_PREFIX,
    save_mode=DEFAULT_SAVE_MODE,
    orderBy=[],
):
    """Load main_summary, apply transform_func, and write to S3"""
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info("Running the {0} ETL job...".format(transform_func.__name__))
    start = datetime.datetime.now()
    spark = SparkSession.builder.appName("search_dashboard_etl").getOrCreate()

    source_path = "s3://{}/{}/submission_date_s3={}".format(
        input_bucket, input_prefix, submission_date
    )
    output_path = "s3://{}/{}/v{}/submission_date_s3={}".format(
        output_bucket, output_prefix, output_version, submission_date
    )

    logger.info("Loading main_summary...")
    main_summary = spark.read.parquet(source_path)

    logger.info("Applying transformation function...")
    search_dashboard_data = transform_func(main_summary)

    if orderBy:
        search_dashboard_data = search_dashboard_data.orderBy(*orderBy)

    logger.info("Saving rollups to: {}".format(output_path))
    search_dashboard_data.write.mode(save_mode).save(output_path)

    stop_session_safely(spark)
    logger.info("... done (took: %s)", str(datetime.datetime.now() - start))
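# Hedged usage sketch for `generate_rollups`: the transform function receives the
# main_summary DataFrame and returns the rollup DataFrame to be written. The
# aggregation and the bucket/prefix names below are illustrative assumptions,
# not the dashboard's real transform or configuration.
from pyspark.sql import functions as F


def example_rollup(main_summary):
    # Toy transform: row counts per country.
    return main_summary.groupBy("country").agg(F.count("*").alias("n"))


# generate_rollups(
#     submission_date="20190101",
#     output_bucket="example-output-bucket",   # hypothetical
#     output_prefix="search/rollups",          # hypothetical
#     output_version=1,
#     transform_func=example_rollup,
#     orderBy=["country"],
# )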
def main(local, submission_date_s3, input_bucket, input_prefix, output_bucket,
         output_prefix):
    """Run a basic count and summary over a sample of main_summary."""
    # Print argument information.
    for k, v in locals().items():
        print("{}: {}".format(k, v))
    print("Python version: {}".format(sys.version_info))

    spark = SparkSession.builder.getOrCreate()
    print("Spark version: {}".format(spark.version))

    # Run a basic count over a sample of `main_summary` from 2 days ago.
    if not local:
        ds_nodash = submission_date_s3
        input_path = format_spark_path(input_bucket, input_prefix)
        output_path = format_spark_path(output_bucket, output_prefix)
        print(
            "Reading data for {ds_nodash} from {input_path} and "
            "writing to {output_path}".format(
                ds_nodash=ds_nodash,
                input_path=input_path,
                output_path=output_path,
            )
        )

        path = "{}/submission_date_s3={}/sample_id={}".format(
            input_path, ds_nodash, 1
        )
        subset = spark.read.parquet(path)
        print("Saw {} documents".format(subset.count()))

        summary = subset.select(
            "memory_mb", "cpu_cores", "subsession_length"
        ).describe()
        summary.show()
        summary.write.parquet(
            output_path + "/submission_date_s3={}/".format(ds_nodash),
            mode="overwrite",
        )

    stop_session_safely(spark)
    print("Done!")
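# `format_spark_path` is referenced above but not shown; a minimal sketch,
# assuming it simply joins a bucket and prefix into an s3:// URI, matching the
# hand-built paths in `generate_rollups` above. The "_sketch" suffix marks it as
# an illustration rather than the repository's real helper.
def format_spark_path_sketch(bucket, prefix):
    return "s3://{}/{}".format(bucket, prefix)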
def test_stop_session_safely_databricks():
    spark_session = Mock(conf={"spark.home": "/databricks/spark"})
    utils.stop_session_safely(spark_session)
    spark_session.stop.assert_not_called()
def test_stop_session_safely_emr():
    spark_session = Mock(conf={"spark.home": "/var/lib/spark/"})
    utils.stop_session_safely(spark_session)
    spark_session.stop.assert_called_once()
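# A minimal sketch of `stop_session_safely`, inferred only from the two tests
# above: leave the session running when `spark.home` points at a Databricks
# install (the platform manages the session itself), otherwise stop it, e.g. on
# EMR. How the real helper reads the Spark configuration is an assumption.
def stop_session_safely_sketch(spark_session):
    spark_home = spark_session.conf["spark.home"]
    if "databricks" not in spark_home:
        spark_session.stop()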