def test_get_per_client_data_doesnt_crash(spark):
    exp = Experiment('a-stub', '20190101', 8)
    enrollments = exp.get_enrollments(
        spark, _get_enrollment_view(slug="a-stub"))

    metrics = _get_metrics(spark)
    metric__how_many_ones = metrics['how_many_ones']

    exp.get_per_client_data(
        enrollments,
        [metric__how_many_ones],
        '20190114',
        0,
        3
    )
def test_get_per_client_data_join(spark):
    exp = Experiment('a-stub', '20190101')

    enrollments = spark.createDataFrame(
        [
            ['aaaa', 'control', '20190101'],
            ['bbbb', 'test', '20190101'],
            ['cccc', 'control', '20190108'],
            ['dddd', 'test', '20190109'],
            ['annie-nodata', 'control', '20190101'],
            ['bob-badtiming', 'test', '20190102'],
            ['carol-gooddata', 'test', '20190101'],
            ['derek-lateisok', 'control', '20190110'],
        ],
        [
            "client_id",
            "branch",
            "enrollment_date",
        ],
    )

    ex_d = {'a-stub': 'fake-branch-lifes-too-short'}
    data_source_df = spark.createDataFrame(
        [
            # bob-badtiming only has data before/after the analysis window;
            # these rows survive `_process_data_source_df`, so the join
            # must be the thing that drops them
            ['bob-badtiming', '20190102', ex_d, 1],
            ['bob-badtiming', '20190106', ex_d, 2],
            # carol-gooddata has data on two days (including a dupe day)
            ['carol-gooddata', '20190102', ex_d, 3],
            ['carol-gooddata', '20190102', ex_d, 2],
            ['carol-gooddata', '20190104', ex_d, 6],
            # derek-lateisok has data before and during the analysis window
            ['derek-lateisok', '20190110', ex_d, 1000],
            ['derek-lateisok', '20190111', ex_d, 1],
            # TODO: exercise the last condition on the join
        ],
        [
            "client_id",
            "submission_date_s3",
            "experiments",
            "some_value",
        ],
    )

    ds = DataSource.from_dataframe('ds', data_source_df)
    metric = Metric.from_col(
        'some_value', agg_sum(data_source_df.some_value), ds)

    res = exp.get_per_client_data(
        enrollments, [metric], '20190114', 1, 3, keep_client_id=True)

    # Check that the dataframe has the correct number of rows
    assert res.count() == enrollments.count()

    # Check that dataless enrollments are handled correctly
    annie_nodata = res.filter(res.client_id == 'annie-nodata')
    assert annie_nodata.count() == 1
    assert annie_nodata.first()['some_value'] == 0

    # Check that early and late data were ignored,
    # i.e. check the join, not just _process_data_source_df
    bob_badtiming = res.filter(res.client_id == 'bob-badtiming')
    assert bob_badtiming.count() == 1
    assert bob_badtiming.first()['some_value'] == 0

    # Check that _process_data_source_df didn't do the heavy lifting above
    time_limits = TimeLimits.for_single_analysis_window(
        exp.start_date, '20190114', 1, 3, exp.num_dates_enrollment)
    pds = exp._process_data_source_df(data_source_df, time_limits)
    assert pds.filter(
        pds.client_id == 'bob-badtiming'
    ).select(
        F.sum(pds.some_value).alias('agg_val')
    ).first()['agg_val'] == 3

    # Check that relevant data was included appropriately
    carol_gooddata = res.filter(res.client_id == 'carol-gooddata')
    assert carol_gooddata.count() == 1
    assert carol_gooddata.first()['some_value'] == 11

    derek_lateisok = res.filter(res.client_id == 'derek-lateisok')
    assert derek_lateisok.count() == 1
    assert derek_lateisok.first()['some_value'] == 1

    # Check that it still works for `data_source`s without an experiments map
    ds_df_noexp = data_source_df.drop('experiments')
    ds_noexp = DataSource.from_dataframe('ds_noexp', ds_df_noexp)
    metric_noexp = Metric.from_col(
        'some_value', agg_sum(ds_df_noexp.some_value), ds_noexp)

    res2 = exp.get_per_client_data(
        enrollments, [metric_noexp], '20190114', 1, 3, keep_client_id=True)

    assert res2.count() == enrollments.count()