def test_reformat_generates_rows_with_all(simple_df):
    """Check the row count of the reformatted simple dataset.

    Every dimension ends up with a cardinality of 2 once the additional
    `all` label is included, so the output should hold 2^3 = 8 rows.
    """
    reformatted = topline.reformat_data(simple_df)
    row_count = reformatted.count()

    assert row_count == 8
def test_reformat_filters_ROW_generated_data(generate_data):
    """Check that a non-target geo does not keep its own label.

    The input is built with the `generate_data` fixture. This test carries
    a distinct name because another function in this module is also called
    `test_reformat_filters_ROW`; with identical names the later definition
    rebinds the name and this test would never be collected.
    """
    # Maldives is not a target region
    input_df = generate_data([{"geo": "MV"}])
    df = topline.reformat_data(input_df)

    # No row may keep the raw 'MV' label ...
    assert df.where("geo='MV'").count() == 0
    # ... while at least one row must be reported under 'Other'.
    assert df.where("geo='Other'").count() > 0
def test_reformat_filters_ROW(spark):
    """A geo outside the target regions must not appear under its own
    label after reformatting, and rows labelled 'Other' must exist."""
    # Maldives is not a target region
    maldives_snippet = [{'geo': 'MV'}]
    source_df = snippets_to_df(
        spark, maldives_snippet, default_sample, topline_schema
    )
    reformatted = topline.reformat_data(source_df)

    maldives_rows = reformatted.where("geo='MV'").count()
    assert maldives_rows == 0

    other_rows = reformatted.where("geo='Other'").count()
    assert other_rows > 0
def test_reformat_prunes_empty_rows_with_all(multi_df):
    """Check which rows survive pruning for the multi-row dataset.

    The full cross product of the attribute fields has 27 combinations.
    Rows that do not contain `all` in any attribute field are excluded,
    which drops 2^3 = 8 of them and leaves 19. Rows that contain values
    of 0 are excluded as well, which removes these three tuples holding
    only a single `all`:

    ('CA', 'release', 'all'),
    ('CA', 'all', 'Linux'),
    ('all', 'release', 'Linux')

    That leaves 16 results.
    """
    reformatted = topline.reformat_data(multi_df)

    # This row should be pruned
    pruned = reformatted.where("geo='CA' AND channel='release'")
    assert pruned.count() == 0

    # This should be the accurate count at the end
    with_all = reformatted.where("geo='all' OR channel='all' OR os='all'")
    assert with_all.count() == 16
def test_reformat_conforms_to_historical_schema(simple_df):
    """The reformatted output must expose exactly the column names listed
    in `historical_schema`, in the same order."""
    reformatted = topline.reformat_data(simple_df)

    actual_columns = reformatted.columns
    expected_columns = historical_schema.names
    assert actual_columns == expected_columns
def test_reformat_aggregates(multi_df):
    """The row with `all` in geo, channel and os should report 3.0 hours
    for the multi-row dataset."""
    reformatted = topline.reformat_data(multi_df)

    overall = reformatted.where("geo='all' AND channel='all' AND os='all'")
    first_row = overall.head()
    assert first_row.hours == 3.0