Exemplo n.º 1
0
def test_process_metrics(spark):
    """Metrics are grouped by the DataFrame backing their data source.

    Three metrics over two registered data sources should yield two groups:
    two columns keyed by ``ds_df_A`` and one keyed by ``ds_df_B``. The single
    ``ds_df_B`` entry must be the ``numeric_col`` column aliased to ``m3``.
    """
    exp = Experiment('a-stub', '20190101', num_dates_enrollment=8)
    enrollments = exp.get_enrollments(spark,
                                      _get_enrollment_view(slug="a-stub"))

    ds_df_A = register_data_source_fixture(spark, name='ds_df_A')
    ds_df_B = register_data_source_fixture(spark, name='ds_df_B')

    ds_A = DataSource.from_dataframe('ds_df_A', ds_df_A)
    ds_B = DataSource.from_dataframe('ds_df_B', ds_df_B)

    m1 = Metric.from_col('m1', ds_df_A.numeric_col, ds_A)
    m2 = Metric.from_col('m2', ds_df_A.bool_col, ds_A)
    m3 = Metric.from_col('m3', ds_df_B.numeric_col, ds_B)

    metric_list = [m1, m2, m3]

    # NOTE(review): this rebinds `exp` with the default num_dates_enrollment,
    # so _process_metrics runs on a different Experiment than the one that
    # produced `enrollments` — confirm that is intentional.
    exp = Experiment('a-stub', '20190101')

    data_sources_and_metrics = exp._process_metrics(enrollments, metric_list)

    assert len(data_sources_and_metrics) == 2

    assert len(data_sources_and_metrics[ds_df_A]) == 2
    assert len(data_sources_and_metrics[ds_df_B]) == 1

    m3_repr = repr(data_sources_and_metrics[ds_df_B][0])
    assert 'numeric_col' in m3_repr
    assert '`m3`' in m3_repr
    # Column.__repr__ differs across PySpark and Python versions; accept
    # each known rendering of the same aliased column.
    assert m3_repr in {
        "Column<'numeric_col AS `m3`'>",  # PySpark >= 3.0
        "Column<b'numeric_col AS `m3`'>",  # PySpark 2.x on py3
        "Column<numeric_col AS `m3`>",  # PySpark 2.x on py2
    }
Exemplo n.º 2
0
def _get_metrics(spark):
    """Return a one-entry dict of stub metrics, keyed by metric name.

    The only metric, ``how_many_ones``, sums the ``constant_one`` column of
    the DataFrame returned by ``_get_data_source_df``. That DataFrame is
    wrapped in a data source named ``bla_ds``.
    """
    data_source_df = _get_data_source_df(spark)
    data_source = DataSource.from_dataframe('bla_ds', data_source_df)

    how_many_ones = Metric.from_col(
        'how_many_ones', agg_sum(data_source_df.constant_one), data_source
    )
    return {'how_many_ones': how_many_ones}
Exemplo n.º 3
0
def test_process_metrics_dupe_data_source(spark):
    """Distinct DataSource objects over one DataFrame share a single group.

    Both metrics read from ``ds_df``, but through two separately constructed
    (identically named) DataSource instances. The processed result must hold
    exactly one key, ``ds_df``, carrying both metrics.
    """
    enrollment_exp = Experiment('a-stub', '20190101', num_dates_enrollment=8)
    enrollments = enrollment_exp.get_enrollments(
        spark, _get_enrollment_view(slug="a-stub"))

    ds_df = register_data_source_fixture(spark, name='ds_df_A')

    first_source = DataSource.from_dataframe('ds_df_A', ds_df)
    second_source = DataSource.from_dataframe('ds_df_A', ds_df)

    metric_list = [
        Metric.from_col('m1', ds_df.numeric_col, first_source),
        Metric.from_col('m2', ds_df.bool_col, second_source),
    ]

    # Metrics are processed with a freshly built, default-config Experiment.
    processing_exp = Experiment('a-stub', '20190101')
    grouped = processing_exp._process_metrics(enrollments, metric_list)

    assert len(grouped) == 1
    assert len(grouped[ds_df]) == 2
Exemplo n.º 4
0
def test_get_per_client_data_join(spark):
    """Check how get_per_client_data joins enrollments to data-source rows.

    The result must have one row per enrollment, and each checked client
    appears exactly once. A client with no data gets ``some_value == 0``.
    Rows outside a client's analysis window must be dropped by the join
    itself, not only by ``_process_data_source_df``. A data source without
    an ``experiments`` column must still yield one row per enrollment.
    """
    exp = Experiment('a-stub', '20190101')

    enrollment_rows = [
        ['aaaa', 'control', '20190101'],
        ['bbbb', 'test', '20190101'],
        ['cccc', 'control', '20190108'],
        ['dddd', 'test', '20190109'],
        ['annie-nodata', 'control', '20190101'],
        ['bob-badtiming', 'test', '20190102'],
        ['carol-gooddata', 'test', '20190101'],
        ['derek-lateisok', 'control', '20190110'],
    ]
    enrollments = spark.createDataFrame(
        enrollment_rows,
        ["client_id", "branch", "enrollment_date"],
    )

    experiments_map = {'a-stub': 'fake-branch-lifes-too-short'}
    data_rows = [
        # bob-badtiming only has rows outside his analysis window; those rows
        # survive _process_data_source_df and must be dropped by the join.
        ['bob-badtiming', '20190102', experiments_map, 1],
        ['bob-badtiming', '20190106', experiments_map, 2],
        # carol-gooddata has rows on two days, one of them duplicated.
        ['carol-gooddata', '20190102', experiments_map, 3],
        ['carol-gooddata', '20190102', experiments_map, 2],
        ['carol-gooddata', '20190104', experiments_map, 6],
        # derek-lateisok has one row before his window and one inside it.
        ['derek-lateisok', '20190110', experiments_map, 1000],
        ['derek-lateisok', '20190111', experiments_map, 1],
        # TODO: exercise the last condition on the join
    ]
    data_source_df = spark.createDataFrame(
        data_rows,
        ["client_id", "submission_date_s3", "experiments", "some_value"],
    )

    ds = DataSource.from_dataframe('ds', data_source_df)
    metric = Metric.from_col(
        'some_value', agg_sum(data_source_df.some_value), ds)

    res = exp.get_per_client_data(
        enrollments, [metric], '20190114', 1, 3, keep_client_id=True)

    def _single_value(frame, client_id):
        """Assert `client_id` has exactly one row in `frame`; return its value."""
        client_rows = frame.filter(frame.client_id == client_id)
        assert client_rows.count() == 1
        return client_rows.first()['some_value']

    # Exactly one output row per enrollment.
    assert res.count() == enrollments.count()

    # A client with no data at all still gets a zero-valued row.
    assert _single_value(res, 'annie-nodata') == 0

    # Rows outside the analysis window are excluded by the join...
    assert _single_value(res, 'bob-badtiming') == 0
    # ...and not already by _process_data_source_df, which keeps both of
    # bob's rows (1 + 2).
    time_limits = TimeLimits.for_single_analysis_window(
        exp.start_date, '20190114', 1, 3, exp.num_dates_enrollment)
    pds = exp._process_data_source_df(data_source_df, time_limits)
    bob_total = pds.filter(pds.client_id == 'bob-badtiming').select(
        F.sum(pds.some_value).alias('agg_val')).first()['agg_val']
    assert bob_total == 3

    # In-window rows are summed, duplicate day included: 3 + 2 + 6.
    assert _single_value(res, 'carol-gooddata') == 11

    # Only derek's in-window row contributes.
    assert _single_value(res, 'derek-lateisok') == 1

    # Data sources without an `experiments` map column must still work.
    ds_df_noexp = data_source_df.drop('experiments')
    ds_noexp = DataSource.from_dataframe('ds_noexp', ds_df_noexp)
    metric_noexp = Metric.from_col(
        'some_value', agg_sum(ds_df_noexp.some_value), ds_noexp)

    res2 = exp.get_per_client_data(
        enrollments, [metric_noexp], '20190114', 1, 3, keep_client_id=True)

    assert res2.count() == enrollments.count()