def test_run_builtin_aggregators_success_multiple(spark, ctx_obj, get_context):
    """Run two builtin aggregators (cortex.sum, cortex.first) over one column.

    Verifies that each aggregate's result is stored and that args are resolved
    through ctx.populate_args exactly once.

    NOTE(review): renamed from ``test_run_builtin_aggregators_success`` —
    another test later in this file defines the same name, and duplicate
    definitions shadow each other so only the last one is collected by pytest.
    The ``test_`` prefix is kept so discovery still picks it up.
    """
    ctx_obj["aggregators"] = {
        "cortex.sum": {"name": "sum", "namespace": "cortex"},
        "cortex.first": {"name": "first", "namespace": "cortex"},
    }
    ctx_obj["aggregates"] = {
        "sum_a": {
            "name": "sum_a",
            "id": "1",
            "aggregator": "cortex.sum",
            "inputs": {"features": {"col": "a"}},
        },
        "first_a": {
            "id": "2",
            "name": "first_a",
            "aggregator": "cortex.first",
            "inputs": {
                "features": {"col": "a"},
                "args": {"ignorenulls": "some_constant"},
            },
        },
    }
    aggregate_list = list(ctx_obj["aggregates"].values())

    ctx = get_context(ctx_obj)
    ctx.store_aggregate_result = MagicMock()
    ctx.populate_args = MagicMock(return_value={"ignorenulls": True})

    # One null row: sum ignores it (6), and first with ignorenulls=True skips it (1).
    data = [Row(a=None), Row(a=1), Row(a=2), Row(a=3)]
    df = spark.createDataFrame(data, StructType([StructField("a", LongType())]))

    spark_util.run_builtin_aggregators(aggregate_list, df, ctx, spark)

    calls = [
        call(6, ctx_obj["aggregates"]["sum_a"]),
        call(1, ctx_obj["aggregates"]["first_a"]),
    ]
    ctx.store_aggregate_result.assert_has_calls(calls, any_order=True)
    # Only first_a carries args, so populate_args must be hit exactly once.
    ctx.populate_args.assert_called_once_with({"ignorenulls": "some_constant"})
def run_custom_aggregators(spark, ctx, cols_to_aggregate, raw_df):
    """Run every aggregator needed for *cols_to_aggregate* against *raw_df*.

    Splits the requested aggregates into builtin and custom ones, executes
    each group while reporting per-resource status through ``ctx``, and
    finally displays the collected results via ``show_aggregates``.

    Args:
        spark: active SparkSession passed through to the aggregator runners.
        ctx: workload context; provides ``ag_id_map`` and resource-status
            reporting (start/failed/success).
        cols_to_aggregate: ids of aggregates to compute (keys of ``ctx.ag_id_map``).
        raw_df: the raw Spark DataFrame to aggregate over.
    """
    logger.info("Aggregating")
    results = {}

    aggregate_names = [ctx.ag_id_map[f]["name"] for f in cols_to_aggregate]
    builtin_aggregates, custom_aggregates = spark_util.split_aggregators(
        sorted(aggregate_names), ctx
    )

    if len(builtin_aggregates) > 0:
        ctx.upload_resource_status_start(*builtin_aggregates)
        try:
            for aggregate in builtin_aggregates:
                logger.info("Aggregating " + ", ".join(ctx.ag_id_map[aggregate["id"]]["aliases"]))
            # Bug fix: run_builtin_aggregators takes the FULL list and computes
            # all builtin aggregates in one pass. The previous code invoked it
            # inside the loop above, re-running the whole set once per aggregate
            # and discarding all but the last (identical) result.
            results = spark_util.run_builtin_aggregators(builtin_aggregates, raw_df, ctx, spark)
        except:  # noqa: E722 — bare on purpose: report failure status even on
            # BaseException (e.g. KeyboardInterrupt) before re-raising.
            ctx.upload_resource_status_failed(*builtin_aggregates)
            raise
        ctx.upload_resource_status_success(*builtin_aggregates)

    for aggregate in custom_aggregates:
        ctx.upload_resource_status_start(aggregate)
        try:
            logger.info("Aggregating " + ", ".join(ctx.ag_id_map[aggregate["id"]]["aliases"]))
            result = spark_util.run_custom_aggregator(aggregate, raw_df, ctx, spark)
            results[aggregate["name"]] = result
        except:  # noqa: E722 — same rationale as above.
            ctx.upload_resource_status_failed(aggregate)
            raise
        ctx.upload_resource_status_success(aggregate)

    show_aggregates(ctx, results)
def test_run_builtin_aggregators_success(spark, ctx_obj, get_context):
    """cortex.sum_int over an INT_COLUMN stores the expected sum (nulls skipped)."""
    ctx_obj["raw_columns"] = {"a": {"id": "2", "name": "a", "type": "INT_COLUMN"}}
    ctx_obj["aggregators"] = {
        "cortex.sum_int": {
            "name": "sum_int",
            "namespace": "cortex",
            "input": {"_type": "INT_COLUMN"},
            "output_type": "INT_COLUMN",
        }
    }
    ctx_obj["aggregates"] = {
        "sum_a": {
            "name": "sum_a",
            "id": "1",
            "aggregator": "cortex.sum_int",
            "input": add_res_ref("a"),
        }
    }

    aggregates = list(ctx_obj["aggregates"].values())
    ctx = get_context(ctx_obj)
    ctx.store_aggregate_result = MagicMock()
    ctx.populate_args = MagicMock(return_value={"ignorenulls": True})

    rows = [Row(a=None), Row(a=1), Row(a=2), Row(a=3)]
    input_df = spark.createDataFrame(rows, StructType([StructField("a", LongType())]))

    spark_util.run_builtin_aggregators(aggregates, input_df, ctx, spark)

    # 1 + 2 + 3 = 6; the null row contributes nothing to the sum.
    ctx.store_aggregate_result.assert_has_calls(
        [call(6, ctx_obj["aggregates"]["sum_a"])], any_order=True
    )
def test_run_builtin_aggregators_error(spark, ctx_obj, get_context):
    """A misspelled arg key ("ignoreNulls" vs "ignorenulls") must raise,
    and no aggregate result may be stored."""
    ctx_obj["aggregators"] = {"cortex.first": {"name": "first", "namespace": "cortex"}}
    ctx_obj["aggregates"] = {
        "first_a": {
            "name": "first_a",
            "aggregator": "cortex.first",
            "inputs": {
                "columns": {"col": "a"},
                "args": {"ignoreNulls": "some_constant"},  # supposed to be ignorenulls
            },
            "id": "1",
        }
    }

    aggregates = list(ctx_obj["aggregates"].values())
    ctx = get_context(ctx_obj)
    ctx.store_aggregate_result = MagicMock()
    ctx.populate_args = MagicMock(return_value={"ignoreNulls": True})

    rows = [Row(a=None), Row(a=1), Row(a=2), Row(a=3)]
    input_df = spark.createDataFrame(rows, StructType([StructField("a", LongType())]))

    with pytest.raises(Exception):
        spark_util.run_builtin_aggregators(aggregates, input_df, ctx, spark)

    ctx.store_aggregate_result.assert_not_called()
    ctx.populate_args.assert_called_once_with({"ignoreNulls": "some_constant"})