Example #1
def test_get_time_series_data_lazy_daily(spark):
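    # A daily time series over an 8-day enrollment period starting 20190101,
    # with data through 20190114, should yield 7 analysis windows, each
    # containing the same 3 enrolled clients.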
    exp = Experiment('a-stub', '20190101', 8)
    enrollments = exp.get_enrollments(spark,
                                      _get_enrollment_view(slug="a-stub"))
    metrics = _get_metrics(spark)
    metric__how_many_ones = metrics['how_many_ones']

    res = exp.get_time_series_data_lazy(
        enrollments,
        [metric__how_many_ones],
        '20190114',
        time_series_period='daily',
        keep_client_id=True,
    )

    assert len(res) == 7

    for df in res.values():
        pdf = df.toPandas()
        assert pdf.client_id.nunique() == 3
        assert len(pdf) == 3

        pdf = pdf.set_index('client_id')

        assert pdf.loc['aaaa', 'how_many_ones'] == 1
        assert pdf.loc['bbbb', 'how_many_ones'] == 1
        assert pdf.loc['cccc', 'how_many_ones'] == 0
        assert (pdf['bla_ds_has_contradictory_branch'] == 0).all()
        assert (pdf['bla_ds_has_non_enrolled_data'] == 0).all()
Example #2
def test_exposure_signal_query_custom_windows():
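    # Custom exposure windows (days 1 to 3 after enrollment) should surface
    # in the enrollments query as DATE_ADD clauses on the start date.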
    exp = Experiment("slug", "2019-01-01", 8, app_id="my_cool_app")

    tl = TimeLimits.for_ts(
        first_enrollment_date="2019-01-01",
        last_date_full_data="2019-03-01",
        time_series_period="weekly",
        num_dates_enrollment=8,
    )

    enrollment_sql = exp.build_enrollments_query(
        time_limits=tl,
        enrollments_query_type="glean-event",
        exposure_signal=ExposureSignal(
            name="exposures",
            data_source=mozanalysis.metrics.fenix.baseline,
            select_expr="metrics.counter.events_total_uri_count > 0",
            friendly_name="URI visited exposure",
            description="Exposed when URI visited",
            window_start=1,
            window_end=3,
        ),
    )

    sql_lint(enrollment_sql)

    assert "exposures" in enrollment_sql
    assert "metrics.counter.events_total_uri_count > 0" in enrollment_sql
    assert "DATE_ADD('2019-01-01', INTERVAL 1 DAY)" in enrollment_sql
    assert "DATE_ADD('2019-01-01', INTERVAL 3 DAY)" in enrollment_sql
Example #3
def test_process_enrollments(spark):
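    # Clients who enrolled after the last enrollment date implied by the
    # analysis window should be dropped, and the processed frame should be
    # reachable only under the 'enrollments' alias.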
    exp = Experiment('a-stub', '20190101')
    enrollments = exp.get_enrollments(spark,
                                      _get_enrollment_view(slug="a-stub"))
    assert enrollments.count() == 4

    # With final data collected on '20190114', we have 7 dates of data
    # for 'cccc' enrolled on '20190108' but not for 'dddd' enrolled on
    # '20190109'.
    tl = TimeLimits.for_single_analysis_window(
        first_enrollment_date=exp.start_date,
        last_date_full_data='20190114',
        analysis_start_days=0,
        analysis_length_dates=7,
        num_dates_enrollment=exp.num_dates_enrollment)
    assert tl.last_enrollment_date == '20190108'
    assert len(tl.analysis_windows) == 1
    assert tl.analysis_windows[0].end == 6

    pe = exp._process_enrollments(enrollments, tl)
    assert pe.count() == 3

    pe = exp._process_enrollments(enrollments.alias('main_summary'), tl)
    assert pe.select(F.col('enrollments.enrollment_date'))
    with pytest.raises(AnalysisException):
        assert pe.select(F.col('main_summary.enrollment_date'))
Example #4
def test_metrics_query_based_on_exposure():
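    # With AnalysisBasis.EXPOSURES, the metrics query should reference the
    # enrollments' exposure_date column.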
    exp = Experiment("slug", "2019-01-01", 8)

    tl = TimeLimits.for_ts(
        first_enrollment_date="2019-01-01",
        last_date_full_data="2019-03-01",
        time_series_period="weekly",
        num_dates_enrollment=8,
    )

    enrollments_sql = exp.build_enrollments_query(
        time_limits=tl, enrollments_query_type="fenix-fallback")

    sql_lint(enrollments_sql)

    metrics_sql = exp.build_metrics_query(
        metric_list=[
            m for m in mozanalysis.metrics.fenix.__dict__.values()
            if isinstance(m, Metric)
        ],
        time_limits=tl,
        enrollments_table="enrollments",
        analysis_basis=AnalysisBasis.EXPOSURES,
    )

    sql_lint(metrics_sql)

    assert "e.exposure_date" in metrics_sql
Example #5
def test_add_analysis_windows_to_enrollments(spark):
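    # Each enrollment should expand into one row per analysis window:
    # 3 clients x 7 daily windows, with window start/end days 0 through 6.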
    exp = Experiment('a-stub', '20190101', num_dates_enrollment=8)
    enrollments = exp.get_enrollments(spark,
                                      _get_enrollment_view(slug="a-stub"))
    assert enrollments.count() == 3

    tl = TimeLimits.for_ts(
        first_enrollment_date=exp.start_date,
        last_date_full_data='20190114',
        time_series_period='daily',
        num_dates_enrollment=exp.num_dates_enrollment,
    )
    assert len(tl.analysis_windows) == 7

    new_enrollments = exp._add_analysis_windows_to_enrollments(enrollments, tl)

    nep = new_enrollments.toPandas()
    assert len(nep) == enrollments.count() * len(tl.analysis_windows)

    a = nep[nep['client_id'] == 'aaaa']
    assert len(a) == len(tl.analysis_windows)
    assert (a.mozanalysis_analysis_window_start.sort_values() == np.arange(
        len(tl.analysis_windows))).all()
    assert (a.mozanalysis_analysis_window_end.sort_values() == np.arange(
        len(tl.analysis_windows))).all()
Example #6
def test_no_analysis_exception_when_shared_parent_dataframe(spark):
    # Check that we don't fall victim to
    # https://issues.apache.org/jira/browse/SPARK-10925
    df = spark.createDataFrame(
        [  # Just need the schema, really
            ['someone', '20190102', 'fake', 1],
        ],
        [
            "client_id",
            "submission_date_s3",
            "branch",
            "some_value",
        ])

    enrollments = df.groupby('client_id', 'branch').agg(
        F.min('submission_date_s3').alias('enrollment_date'))

    exp = Experiment('a-stub', '20180101')

    time_limits = TimeLimits.for_single_analysis_window(
        exp.start_date,
        last_date_full_data='20190522',
        analysis_start_days=28,
        analysis_length_dates=7)

    enrollments = exp._add_analysis_windows_to_enrollments(
        enrollments, time_limits)

    exp._get_results_for_one_data_source(
        enrollments,
        df,
        [F.max(F.col('some_value'))],
    )
Example #7
def test_segments_megaquery_not_detectably_malformed():
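    # Build enrollments and metrics queries using every Segment and Metric
    # defined in the msd and mad modules, then lint both.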
    exp = Experiment("slug", "2019-01-01", 8)

    tl = TimeLimits.for_ts(
        first_enrollment_date="2019-01-01",
        last_date_full_data="2019-03-01",
        time_series_period="weekly",
        num_dates_enrollment=8,
    )

    enrollments_sql = exp.build_enrollments_query(
        time_limits=tl,
        segment_list=[s for s in msd.__dict__.values() if isinstance(s, msd.Segment)],
        enrollments_query_type="normandy",
    )

    sql_lint(enrollments_sql)

    metrics_sql = exp.build_metrics_query(
        metric_list=[m for m in mad.__dict__.values() if isinstance(m, mad.Metric)],
        time_limits=tl,
        enrollments_table="enrollments",
    )

    sql_lint(metrics_sql)
Example #8
def test_get_per_client_data_doesnt_crash(spark):
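    # Smoke test: computing per-client data over a single 3-day analysis
    # window starting at day 0 should complete without raising.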
    exp = Experiment('a-stub', '20190101', 8)
    enrollments = exp.get_enrollments(spark,
                                      _get_enrollment_view(slug="a-stub"))
    metrics = _get_metrics(spark)
    metric__how_many_ones = metrics['how_many_ones']

    exp.get_per_client_data(enrollments, [metric__how_many_ones], '20190114',
                            0, 3)
Example #9
def test_get_enrollments_debug_dupes(spark):
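    # debug_dupes=True should add a num_events column counting enrollment
    # events per client; without the flag the column should be absent.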
    exp = Experiment('a-stub', '20190101')
    view_method = _get_enrollment_view("a-stub")

    enrl = exp.get_enrollments(spark, view_method)
    assert 'num_events' not in enrl.columns

    enrl2 = exp.get_enrollments(spark, view_method, debug_dupes=True)
    assert 'num_events' in enrl2.columns

    penrl2 = enrl2.toPandas()
    assert (penrl2['num_events'] == 1).all()
Example #10
def test_process_data_source_df(spark):
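    # _process_data_source_df should clip the data source to exactly the
    # date range the analysis window requires.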
    start_date = '20190101'
    exp_8d = Experiment('experiment-with-8-day-cohort', start_date, 8)
    data_source_df = _get_data_source_df(spark)

    end_date = '20190114'

    # Are the fixtures sufficiently complicated that we're actually testing
    # things?
    assert _simple_return_agg_date(F.min, data_source_df) < start_date
    assert _simple_return_agg_date(F.max, data_source_df) > end_date

    tl_03 = TimeLimits.for_single_analysis_window(
        first_enrollment_date=exp_8d.start_date,
        last_date_full_data=end_date,
        analysis_start_days=0,
        analysis_length_dates=3,
        num_dates_enrollment=exp_8d.num_dates_enrollment)
    assert tl_03.first_date_data_required == start_date
    assert tl_03.last_date_data_required == '20190110'

    proc_ds = exp_8d._process_data_source_df(data_source_df, tl_03)

    assert _simple_return_agg_date(F.min,
                                   proc_ds) == tl_03.first_date_data_required
    assert _simple_return_agg_date(F.max,
                                   proc_ds) == tl_03.last_date_data_required

    tl_23 = TimeLimits.for_single_analysis_window(
        first_enrollment_date=exp_8d.start_date,
        last_date_full_data=end_date,
        analysis_start_days=2,
        analysis_length_dates=3,
        num_dates_enrollment=exp_8d.num_dates_enrollment)
    assert tl_23.first_date_data_required == add_days(start_date, 2)
    assert tl_23.last_date_data_required == '20190112'

    p_ds_2 = exp_8d._process_data_source_df(data_source_df, tl_23)

    assert _simple_return_agg_date(F.min,
                                   p_ds_2) == tl_23.first_date_data_required
    assert _simple_return_agg_date(F.max,
                                   p_ds_2) == tl_23.last_date_data_required

    assert proc_ds.select(F.col('data_source.client_id'))
    with pytest.raises(AnalysisException):
        assert data_source_df.select(F.col('data_source.client_id'))
Example #11
def test_megaquery_not_detectably_malformed():
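    # Build the combined query using every Metric defined in the mad module
    # and check that it lints cleanly.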
    exp = Experiment('slug', '2019-01-01', 8)

    tl = TimeLimits.for_ts(first_enrollment_date='2019-01-01',
                           last_date_full_data='2019-03-01',
                           time_series_period='weekly',
                           num_dates_enrollment=8)

    sql = exp.build_query(
        metric_list=[
            m for m in mad.__dict__.values() if isinstance(m, mad.Metric)
        ],
        time_limits=tl,
        enrollments_query_type='normandy',
    )

    sql_lint(sql)
Example #12
def test_query_not_detectably_malformed():
    exp = Experiment('slug', '2019-01-01', 8)

    tl = TimeLimits.for_ts(
        first_enrollment_date='2019-01-01',
        last_date_full_data='2019-03-01',
        time_series_period='weekly',
        num_dates_enrollment=8
    )

    sql = exp.build_query(
        metric_list=[],
        time_limits=tl,
        enrollments_query_type='normandy',
    )

    sql_lint(sql)
Example #13
def test_firefox_ios_klar_app_id_propagation():
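    # An app_id set on the Experiment should replace the segment data
    # source's default dataset in the generated enrollments SQL.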
    exp = Experiment("slug", "2019-01-01", 8, app_id="my_cool_app")

    tl = TimeLimits.for_ts(
        first_enrollment_date="2019-01-01",
        last_date_full_data="2019-03-01",
        time_series_period="weekly",
        num_dates_enrollment=8,
    )

    sds = SegmentDataSource(
        name="cool_data_source",
        from_expr="`moz-fx-data-shared-prod`.{dataset}.cool_table",
        default_dataset="org_mozilla_ios_klar",
    )

    segment = Segment(
        name="cool_segment",
        select_expr="COUNT(*)",
        data_source=sds,
    )

    enrollments_sql = exp.build_enrollments_query(
        time_limits=tl,
        segment_list=[segment],
        enrollments_query_type="glean-event",
    )

    sql_lint(enrollments_sql)

    metrics_sql = exp.build_metrics_query(
        metric_list=[
            m for m in mozanalysis.metrics.klar_ios.__dict__.values()
            if isinstance(m, Metric)
        ],
        time_limits=tl,
        enrollments_table="enrollments",
    )

    sql_lint(metrics_sql)

    assert "org_mozilla_ios_klar" not in enrollments_sql
    assert "my_cool_app" in enrollments_sql

    sql_lint(metrics_sql)
Example #14
def test_exposure_query():
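    # A glean-event enrollments query should mention exposures even when no
    # explicit ExposureSignal is supplied.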
    exp = Experiment("slug", "2019-01-01", 8, app_id="my_cool_app")

    tl = TimeLimits.for_ts(
        first_enrollment_date="2019-01-01",
        last_date_full_data="2019-03-01",
        time_series_period="weekly",
        num_dates_enrollment=8,
    )

    enrollment_sql = exp.build_enrollments_query(
        time_limits=tl,
        enrollments_query_type="glean-event",
    )

    sql_lint(enrollment_sql)

    assert "exposures" in enrollment_sql
Example #15
def test_query_not_detectably_malformed():
    exp = Experiment('slug', '2019-01-01', 8)

    tl = TimeLimits.for_ts(first_enrollment_date='2019-01-01',
                           last_date_full_data='2019-03-01',
                           time_series_period='weekly',
                           num_dates_enrollment=8)

    sql = exp.build_query(
        metric_list=[],
        time_limits=tl,
        enrollments_query_type='normandy',
    )

    # This query is actually slightly malformed, due to a trailing comma.
    # We should add a metric here if the linter ever improves.

    sql_lint(sql)
Example #16
def dry_run_query(exp_path):
    report = validate_schema(op.join(exp_path, "report.json"))
    metric_list = _make_metric_list(report)

    exp = Experiment(experiment_slug=report["experiment_slug"],
                     start_date=report["start_date"],
                     num_dates_enrollment=report["num_dates_enrollment"])
    # build the SQL that generates the analysis, so it can be archived
    time_limits = TimeLimits.for_single_analysis_window(
        first_enrollment_date=report['start_date'],
        last_date_full_data=report['last_date_full_data'],
        analysis_start_days=report['analysis_start_days'],
        analysis_length_dates=report['analysis_length_days'],
        num_dates_enrollment=report['num_dates_enrollment'])
    query = exp.build_query(metric_list=metric_list,
                            time_limits=time_limits,
                            enrollments_query_type='normandy')

    return query
Example #17
def aggregate_data(exp_path):
    report = validate_schema(op.join(op.abspath(exp_path), "report.json"))
    exp = Experiment(experiment_slug=report["experiment_slug"],
                     start_date=report["start_date"],
                     num_dates_enrollment=report["num_dates_enrollment"])

    bq_context = BigQueryContext(dataset_id=report["dataset_id"])
    metric_list = _make_metric_list(report)

    single_window_res = exp.get_single_window_data(
        bq_context=bq_context,
        metric_list=metric_list,
        last_date_full_data=report["last_date_full_data"],
        analysis_start_days=report["analysis_start_days"],
        analysis_length_days=report["analysis_length_days"])
    # TODO: Figure out another way to deal with missing values per client
    single_window_res.dropna(inplace=True)

    return single_window_res
Example #18
def test_get_time_series_data(spark):
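    # A weekly time series with an 8-day enrollment period and data through
    # 20190128 should produce 3 analysis windows, keyed by start day;
    # res[0] and res[14] are the first and third weeks.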
    exp = Experiment('a-stub', '20190101', 8)
    enrollments = exp.get_enrollments(spark,
                                      _get_enrollment_view(slug="a-stub"))
    metrics = _get_metrics(spark)
    metric__how_many_ones = metrics['how_many_ones']

    res = exp.get_time_series_data(
        enrollments,
        [metric__how_many_ones],
        '20190128',
        time_series_period='weekly',
        keep_client_id=True,
    )

    assert len(res) == 3
    df = res[0]
    assert df.client_id.nunique() == 3
    assert len(df) == 3

    df = df.set_index('client_id')

    assert df.loc['aaaa', 'how_many_ones'] == 7
    assert df.loc['bbbb', 'how_many_ones'] == 7
    assert df.loc['cccc', 'how_many_ones'] == 0
    assert (df['bla_ds_has_contradictory_branch'] == 0).all()
    assert (df['bla_ds_has_non_enrolled_data'] == 0).all()

    df = res[14]
    assert df.client_id.nunique() == 3
    assert len(df) == 3

    df = df.set_index('client_id')

    assert df.loc['aaaa', 'how_many_ones'] == 1
    assert df.loc['bbbb', 'how_many_ones'] == 1
    assert df.loc['cccc', 'how_many_ones'] == 0
    assert (df['bla_ds_has_contradictory_branch'] == 0).all()
    assert (df['bla_ds_has_non_enrolled_data'] == 0).all()
Example #19
def test_process_metrics(spark):
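    # Metrics should be grouped by their parent data source: three metrics
    # across two data sources yield a two-entry mapping, with m1 and m2
    # grouped under ds_df_A.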
    exp = Experiment('a-stub', '20190101', num_dates_enrollment=8)
    enrollments = exp.get_enrollments(spark,
                                      _get_enrollment_view(slug="a-stub"))

    ds_df_A = register_data_source_fixture(spark, name='ds_df_A')
    ds_df_B = register_data_source_fixture(spark, name='ds_df_B')

    ds_A = DataSource.from_dataframe('ds_df_A', ds_df_A)
    ds_B = DataSource.from_dataframe('ds_df_B', ds_df_B)

    m1 = Metric.from_col('m1', ds_df_A.numeric_col, ds_A)
    m2 = Metric.from_col('m2', ds_df_A.bool_col, ds_A)
    m3 = Metric.from_col('m3', ds_df_B.numeric_col, ds_B)

    metric_list = [m1, m2, m3]

    exp = Experiment('a-stub', '20190101')

    data_sources_and_metrics = exp._process_metrics(enrollments, metric_list)

    assert len(data_sources_and_metrics) == 2

    assert len(data_sources_and_metrics[ds_df_A]) == 2
    assert len(data_sources_and_metrics[ds_df_B]) == 1

    assert 'numeric_col' in repr(data_sources_and_metrics[ds_df_B][0])
    assert '`m3`' in repr(data_sources_and_metrics[ds_df_B][0])
    assert repr(data_sources_and_metrics[ds_df_B][0]) in {
        "Column<b'numeric_col AS `m3`'>",  # py3
        "Column<numeric_col AS `m3`>",  # py2
    }
Example #20
def test_get_enrollments(spark):
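    # Enrollment counts and min/max enrollment dates should reflect each
    # experiment's start date and optional enrollment-period length.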
    exp = Experiment('a-stub', '20190101')
    view_method = _get_enrollment_view("a-stub")
    assert exp.get_enrollments(spark, view_method).count() == 4

    exp2 = Experiment('a-stub2', '20190102')
    view_method2 = _get_enrollment_view("a-stub2")
    enrl2 = exp2.get_enrollments(spark, study_type=view_method2)
    assert enrl2.count() == 2
    assert enrl2.select(F.min(
        enrl2.enrollment_date).alias('b')).first()['b'] == '20190108'

    exp_8d = Experiment('experiment-with-8-day-cohort', '20190101', 8)
    view_method_8d = _get_enrollment_view("experiment-with-8-day-cohort")
    enrl_8d = exp_8d.get_enrollments(spark, view_method_8d)
    assert enrl_8d.count() == 3
    assert enrl_8d.select(F.max(
        enrl_8d.enrollment_date).alias('b')).first()['b'] == '20190108'
Example #21
def test_query_not_detectably_malformed_fenix_fallback():
    exp = Experiment("slug", "2019-01-01", 8)

    tl = TimeLimits.for_ts(
        first_enrollment_date="2019-01-01",
        last_date_full_data="2019-03-01",
        time_series_period="weekly",
        num_dates_enrollment=8,
    )

    enrollments_sql = exp.build_enrollments_query(
        time_limits=tl, enrollments_query_type="fenix-fallback")

    sql_lint(enrollments_sql)

    metrics_sql = exp.build_metrics_query(
        metric_list=[],
        time_limits=tl,
        enrollments_table="enrollments",
    )

    sql_lint(metrics_sql)
Example #22
def test_metrics_query_with_exposure_signal_custom_windows():
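    # The same custom exposure windows (days 1 to 3 after enrollment) should
    # also appear as DATE_ADD clauses in the metrics query.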
    exp = Experiment("slug", "2019-01-01", 8)

    tl = TimeLimits.for_ts(
        first_enrollment_date="2019-01-01",
        last_date_full_data="2019-03-01",
        time_series_period="weekly",
        num_dates_enrollment=8,
    )

    enrollments_sql = exp.build_enrollments_query(
        time_limits=tl, enrollments_query_type="fenix-fallback")

    sql_lint(enrollments_sql)

    metrics_sql = exp.build_metrics_query(
        metric_list=[
            m for m in mozanalysis.metrics.fenix.__dict__.values()
            if isinstance(m, Metric)
        ],
        time_limits=tl,
        enrollments_table="enrollments",
        analysis_basis=AnalysisBasis.EXPOSURES,
        exposure_signal=ExposureSignal(
            name="exposures",
            data_source=mozanalysis.metrics.fenix.baseline,
            select_expr="metrics.counter.events_total_uri_count > 0",
            friendly_name="URI visited exposure",
            description="Exposed when URI visited",
            window_start=1,
            window_end=3,
        ),
    )

    sql_lint(metrics_sql)

    assert "DATE_ADD('2019-01-01', INTERVAL 1 DAY)" in metrics_sql
    assert "DATE_ADD('2019-01-01', INTERVAL 3 DAY)" in metrics_sql
Example #23
def test_process_metrics_dupe_data_source(spark):
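    # Two DataSource objects wrapping the same dataframe should be
    # deduplicated into a single entry that holds both metrics.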
    exp = Experiment('a-stub', '20190101', num_dates_enrollment=8)
    enrollments = exp.get_enrollments(spark,
                                      _get_enrollment_view(slug="a-stub"))

    ds_df = register_data_source_fixture(spark, name='ds_df_A')

    ds_1 = DataSource.from_dataframe('ds_df_A', ds_df)
    ds_2 = DataSource.from_dataframe('ds_df_A', ds_df)

    m1 = Metric.from_col('m1', ds_df.numeric_col, ds_1)
    m2 = Metric.from_col('m2', ds_df.bool_col, ds_2)

    metric_list = [m1, m2]

    exp = Experiment('a-stub', '20190101')

    data_sources_and_metrics = exp._process_metrics(enrollments, metric_list)

    assert len(data_sources_and_metrics) == 1

    assert len(data_sources_and_metrics[ds_df]) == 2
Example #24
def test_get_per_client_data_join(spark):
    exp = Experiment('a-stub', '20190101')

    enrollments = spark.createDataFrame(
        [
            ['aaaa', 'control', '20190101'],
            ['bbbb', 'test', '20190101'],
            ['cccc', 'control', '20190108'],
            ['dddd', 'test', '20190109'],
            ['annie-nodata', 'control', '20190101'],
            ['bob-badtiming', 'test', '20190102'],
            ['carol-gooddata', 'test', '20190101'],
            ['derek-lateisok', 'control', '20190110'],
        ],
        [
            "client_id",
            "branch",
            "enrollment_date",
        ],
    )

    ex_d = {'a-stub': 'fake-branch-lifes-too-short'}
    data_source_df = spark.createDataFrame(
        [
            # bob-badtiming only has data before/after his analysis window,
            # i.e. rows that `_process_data_source_df` alone won't filter out
            ['bob-badtiming', '20190102', ex_d, 1],
            ['bob-badtiming', '20190106', ex_d, 2],
            # carol-gooddata has data on two days (including a dupe day)
            ['carol-gooddata', '20190102', ex_d, 3],
            ['carol-gooddata', '20190102', ex_d, 2],
            ['carol-gooddata', '20190104', ex_d, 6],
            # derek-lateisok has data before and during the analysis window
            ['derek-lateisok', '20190110', ex_d, 1000],
            ['derek-lateisok', '20190111', ex_d, 1],
            # TODO: exercise the last condition on the join
        ],
        [
            "client_id",
            "submission_date_s3",
            "experiments",
            "some_value",
        ],
    )

    ds = DataSource.from_dataframe('ds', data_source_df)
    metric = Metric.from_col('some_value', agg_sum(data_source_df.some_value),
                             ds)

    res = exp.get_per_client_data(enrollments, [metric],
                                  '20190114',
                                  1,
                                  3,
                                  keep_client_id=True)

    # Check that the dataframe has the correct number of rows
    assert res.count() == enrollments.count()

    # Check that dataless enrollments are handled correctly
    annie_nodata = res.filter(res.client_id == 'annie-nodata')
    assert annie_nodata.count() == 1
    assert annie_nodata.first()['some_value'] == 0

    # Check that early and late data were ignored
    # i.e. check the join, not just _process_data_source_df
    bob_badtiming = res.filter(res.client_id == 'bob-badtiming')
    assert bob_badtiming.count() == 1
    assert bob_badtiming.first()['some_value'] == 0
    # Check that _process_data_source_df didn't do the
    # heavy lifting above
    time_limits = TimeLimits.for_single_analysis_window(
        exp.start_date, '20190114', 1, 3, exp.num_dates_enrollment)
    pds = exp._process_data_source_df(data_source_df, time_limits)
    assert pds.filter(pds.client_id == 'bob-badtiming').select(
        F.sum(pds.some_value).alias('agg_val')).first()['agg_val'] == 3

    # Check that relevant data was included appropriately
    carol_gooddata = res.filter(res.client_id == 'carol-gooddata')
    assert carol_gooddata.count() == 1
    assert carol_gooddata.first()['some_value'] == 11

    derek_lateisok = res.filter(res.client_id == 'derek-lateisok')
    assert derek_lateisok.count() == 1
    assert derek_lateisok.first()['some_value'] == 1

    # Check that it still works for `data_source`s without an experiments map
    ds_df_noexp = data_source_df.drop('experiments')
    ds_noexp = DataSource.from_dataframe('ds_noexp', ds_df_noexp)
    metric_noexp = Metric.from_col('some_value',
                                   agg_sum(ds_df_noexp.some_value), ds_noexp)

    res2 = exp.get_per_client_data(enrollments, [metric_noexp],
                                   '20190114',
                                   1,
                                   3,
                                   keep_client_id=True)

    assert res2.count() == enrollments.count()
Example #25
def test_get_results_for_one_data_source(spark):
    exp = Experiment('a-stub', '20190101')

    enrollments = spark.createDataFrame(
        [
            ['aaaa', 'control', '20190101'],
            ['bbbb', 'test', '20190101'],
            ['cccc', 'control', '20190108'],
            ['dddd', 'test', '20190109'],
            ['annie-nodata', 'control', '20190101'],
            ['bob-badtiming', 'test', '20190102'],
            ['carol-gooddata', 'test', '20190101'],
            ['derek-lateisok', 'control', '20190110'],
        ],
        [
            "client_id",
            "branch",
            "enrollment_date",
        ],
    )

    ex_d = {'a-stub': 'fake-branch-lifes-too-short'}
    data_source = spark.createDataFrame(
        [
            # bob-badtiming only has data before/after his analysis window,
            # i.e. rows that `_process_data_source_df` alone won't filter out
            ['bob-badtiming', '20190102', ex_d, 1],
            ['bob-badtiming', '20190106', ex_d, 2],
            # carol-gooddata has data on two days (including a dupe day)
            ['carol-gooddata', '20190102', ex_d, 3],
            ['carol-gooddata', '20190102', ex_d, 2],
            ['carol-gooddata', '20190104', ex_d, 6],
            # derek-lateisok has data before and during the analysis window
            ['derek-lateisok', '20190110', ex_d, 1000],
            ['derek-lateisok', '20190111', ex_d, 1],
            # TODO: exercise the last condition on the join
        ],
        [
            "client_id",
            "submission_date_s3",
            "experiments",
            "some_value",
        ],
    )

    time_limits = TimeLimits.for_single_analysis_window(
        exp.start_date,
        '20190114',
        1,
        3,
    )

    enrollments = exp._add_analysis_windows_to_enrollments(
        enrollments, time_limits)

    res = exp._get_results_for_one_data_source(
        enrollments,
        data_source,
        [
            F.coalesce(F.sum(data_source.some_value),
                       F.lit(0)).alias('some_value'),
        ],
    )

    # Check that the dataframe has the correct number of rows
    assert res.count() == enrollments.count()

    # Check that dataless enrollments are handled correctly
    annie_nodata = res.filter(res.client_id == 'annie-nodata')
    assert annie_nodata.count() == 1
    assert annie_nodata.first()['some_value'] == 0

    # Check that early and late data were ignored
    bob_badtiming = res.filter(res.client_id == 'bob-badtiming')
    assert bob_badtiming.count() == 1
    assert bob_badtiming.first()['some_value'] == 0

    # Check that relevant data was included appropriately
    carol_gooddata = res.filter(res.client_id == 'carol-gooddata')
    assert carol_gooddata.count() == 1
    assert carol_gooddata.first()['some_value'] == 11

    derek_lateisok = res.filter(res.client_id == 'derek-lateisok')
    assert derek_lateisok.count() == 1
    assert derek_lateisok.first()['some_value'] == 1

    # Check that it still works for `data_source`s without an experiments map
    res2 = exp._get_results_for_one_data_source(
        enrollments,
        data_source.drop('experiments'),
        [
            F.coalesce(F.sum(data_source.some_value),
                       F.lit(0)).alias('some_value'),
        ],
    )

    assert res2.count() == enrollments.count()