def test_load_spark_df(
    size,
    num_samples,
    num_movies,
    movie_example,
    title_example,
    genres_example,
    year_example,
    tmp,
    spark,
):
    """Test MovieLens dataset load into pySpark.DataFrame.

    Three scenarios are exercised in order: a schema and a header supplied
    together, a five-name header combined with the title/genres/year feature
    columns, and a call that relies on the loader defaults.

    Args:
        size: Dataset size identifier forwarded to ``load_spark_df``.
        num_samples: Expected number of rows in the loaded DataFrame.
        num_movies: Not used by this test; kept so the signature is unchanged.
        movie_example: Item id used to select the rows whose features are checked.
        title_example: Expected "Title" value for ``movie_example``.
        genres_example: Expected "Genres" value for ``movie_example``.
        year_example: Expected "Year" value for ``movie_example``.
        tmp: Directory passed as ``local_cache_path``; its listing is inspected.
        spark: Object passed as the first argument of ``load_spark_df``
            (presumably the Spark session fixture).
    """
    # Schema and header supplied together.  Only the loader call sits inside
    # ``pytest.warns`` so that a warning raised by a later Spark action cannot
    # satisfy the check on the loader's behalf.
    header = ["1", "2", "3"]
    schema = StructType(
        [
            StructField("u", IntegerType()),
            StructField("m", IntegerType()),
        ]
    )
    with pytest.warns(Warning):
        df = load_spark_df(
            spark,
            size=size,
            local_cache_path=tmp,
            header=header,
            schema=schema,
        )
    assert df.count() == num_samples
    # The schema is used when both schema and header are provided
    assert len(df.columns) == len(schema)
    # Raw-zip file, rating file, and item file are cached
    assert len(os.listdir(tmp)) == 3

    # Title, genres, and released year load with a header that has one name
    # more than the four rating columns.
    header = ["a", "b", "c", "d", "e"]
    with pytest.warns(Warning):
        df = load_spark_df(
            spark,
            size=size,
            local_cache_path=tmp,
            header=header,
            title_col="Title",
            genres_col="Genres",
            year_col="Year",
        )
    assert df.count() == num_samples
    # 4 header columns (user, item, rating, timestamp) and 3 feature columns
    assert len(df.columns) == 7
    assert "e" not in df.columns  # only the first 4 header columns are used

    # Get two records of the same item and check that the item features are
    # identical.  All three feature columns are fetched with one collect()
    # instead of one Spark job per column.
    rows = (
        df.filter(col("b") == movie_example)
        .limit(2)
        .select("Title", "Genres", "Year")
        .collect()
    )
    first, second = rows[0], rows[1]
    expected_features = (
        ("Title", title_example),
        ("Genres", genres_example),
        ("Year", year_example),
    )
    for feature, expected in expected_features:
        assert first[feature] == second[feature]
        assert first[feature] == expected

    # Default arguments: user, item, rating and timestamp columns only
    df = load_spark_df(spark, size)
    assert df.count() == num_samples
    assert len(df.columns) == 4
def test_load_spark_df(size, num_samples, num_movies, title_example,
                       genres_example):
    """Test MovieLens dataset load into pySpark.DataFrame.

    Checks the row/column counts of the loaded ratings, header handling,
    the optional title and genres columns, the movie-only load, rejection of
    invalid arguments, and schema/header precedence.

    Args:
        size: Dataset size identifier forwarded to ``movielens.load_spark_df``.
        num_samples: Expected number of rating rows for ``size``.
        num_movies: Expected number of rows when only movie data is loaded.
        title_example: Expected "Title" value for item id 1.
        genres_example: Expected "Genres" value for item id 1.
    """
    spark = start_or_get_spark("MovieLensLoaderTesting")

    # Check that the function loads the correct dataset
    df = movielens.load_spark_df(spark, size=size)
    assert df.count() == num_samples
    assert len(df.columns) == 4

    # A header shorter than four names yields that many columns
    header = ["a"]
    df = movielens.load_spark_df(spark, header=header)
    assert len(df.columns) == len(header)

    # A header longer than four names is expected to warn.  Only the loader
    # call sits inside ``pytest.warns`` so nothing else can satisfy the check.
    header = ["a", "b", "c", "d", "e"]
    with pytest.warns(Warning):
        df = movielens.load_spark_df(spark, header=header)
    assert len(df.columns) == 4

    # Title load: two rating rows of item id 1 must carry the same title
    df = movielens.load_spark_df(spark, size=size, title_col="Title")
    assert len(df.columns) == 5
    title_rows = (
        df.filter(col(DEFAULT_ITEM_COL) == 1)
        .select("Title")
        .limit(2)
        .collect()
    )
    assert title_rows[0][0] == title_rows[1][0]
    assert title_rows[0][0] == title_example

    # Genres load: same check for the genres column
    df = movielens.load_spark_df(spark, size=size, genres_col="Genres")
    assert len(df.columns) == 5
    genres_rows = (
        df.filter(col(DEFAULT_ITEM_COL) == 1)
        .select("Genres")
        .limit(2)
        .collect()
    )
    assert genres_rows[0][0] == genres_rows[1][0]
    assert genres_rows[0][0] == genres_example

    # Movie data load (not rating data): header=None with feature columns
    df = movielens.load_spark_df(
        spark,
        size=size,
        header=None,
        title_col="Title",
        genres_col="Genres",
    )
    assert df.count() == num_movies
    assert len(df.columns) == 3

    # Wrong size argument is rejected
    with pytest.raises(ValueError):
        movielens.load_spark_df(spark, size='10k')
    # Wrong cache path argument is rejected
    with pytest.raises(ValueError):
        movielens.load_spark_df(spark, local_cache_path='.')

    # The schema is used when both schema and header are provided; the call
    # itself is expected to warn, so the assertion stays outside the context.
    header = ["1", "2"]
    schema = StructType([StructField("u", IntegerType())])
    with pytest.warns(Warning):
        df = movielens.load_spark_df(spark, header=header, schema=schema)
    assert len(df.columns) == len(schema)
def test_load_spark_df(
    size,
    num_samples,
    num_movies,
    movie_example,
    title_example,
    genres_example,
    year_example,
    tmp,
):
    """Test MovieLens dataset load into pySpark.DataFrame.

    Three scenarios are exercised in order: a schema and a header supplied
    together, a five-name header combined with the title/genres/year feature
    columns, and a call that relies on the loader defaults.

    Args:
        size: Dataset size identifier forwarded to ``load_spark_df``.
        num_samples: Expected number of rows in the loaded DataFrame.
        num_movies: Not used by this test; kept so the signature is unchanged.
        movie_example: Item id used to select the rows whose features are checked.
        title_example: Expected "Title" value for ``movie_example``.
        genres_example: Expected "Genres" value for ``movie_example``.
        year_example: Expected "Year" value for ``movie_example``.
        tmp: Directory passed as ``local_cache_path``; its listing is inspected.
    """
    spark = start_or_get_spark("MovieLensLoaderTesting")

    # Schema and header supplied together.  Only the loader call sits inside
    # ``pytest.warns`` so that a warning raised by a later Spark action cannot
    # satisfy the check on the loader's behalf.
    header = ["1", "2", "3"]
    schema = StructType(
        [
            StructField("u", IntegerType()),
            StructField("m", IntegerType()),
        ]
    )
    with pytest.warns(Warning):
        df = load_spark_df(
            spark, size=size, local_cache_path=tmp, header=header, schema=schema
        )
    assert df.count() == num_samples
    # The schema is used when both schema and header are provided
    assert len(df.columns) == len(schema)
    # Raw-zip file, rating file, and item file are cached
    assert len(os.listdir(tmp)) == 3

    # Title, genres, and released year load with a header that has one name
    # more than the four rating columns.
    header = ["a", "b", "c", "d", "e"]
    with pytest.warns(Warning):
        df = load_spark_df(
            spark,
            size=size,
            local_cache_path=tmp,
            header=header,
            title_col="Title",
            genres_col="Genres",
            year_col="Year",
        )
    assert df.count() == num_samples
    # 4 header columns (user, item, rating, timestamp) and 3 feature columns
    assert len(df.columns) == 7
    assert "e" not in df.columns  # only the first 4 header columns are used

    # Get two records of the same item and check that the item features are
    # identical.  All three feature columns are fetched with one collect()
    # instead of one Spark job per column.
    rows = (
        df.filter(col("b") == movie_example)
        .limit(2)
        .select("Title", "Genres", "Year")
        .collect()
    )
    first, second = rows[0], rows[1]
    expected_features = (
        ("Title", title_example),
        ("Genres", genres_example),
        ("Year", year_example),
    )
    for feature, expected in expected_features:
        assert first[feature] == second[feature]
        assert first[feature] == expected

    # Default arguments: user, item, rating and timestamp columns only
    df = load_spark_df(spark, size)
    assert df.count() == num_samples
    assert len(df.columns) == 4
def test_load_spark_df():
    """Test MovieLens dataset load into pySpark.DataFrame.

    Loads every supported dataset size and checks its row and column counts,
    then verifies argument validation and how headers and schemas of
    different lengths (and both together) are handled.
    """
    spark = start_or_get_spark("MovieLensLoaderTesting")

    # Check that the function loads the correct dataset for each size
    expected_row_counts = (
        ("100k", 100000),
        ("1m", 1000209),
        ("10m", 10000054),
        ("20m", 20000263),
    )
    for size, expected_rows in expected_row_counts:
        ratings = movielens.load_spark_df(spark, size=size)
        assert ratings.count() == expected_rows
        assert len(ratings.columns) == 4

    # Wrong size argument is rejected
    with pytest.raises(ValueError):
        movielens.load_spark_df(spark, size='10k')
    # Wrong cache path argument is rejected
    with pytest.raises(ValueError):
        movielens.load_spark_df(spark, local_cache_path='.')

    # A header shorter than four names yields that many columns
    header = ["a", "b", "c"]
    with_header = movielens.load_spark_df(spark, header=header)
    assert with_header.count() == 100000
    assert len(with_header.columns) == len(header)

    # A header longer than four names is expected to warn.  Only the loader
    # call sits inside ``pytest.warns`` so that a warning raised by a later
    # Spark action cannot satisfy the check on the loader's behalf.
    header = ["a", "b", "c", "d", "e"]
    with pytest.warns(Warning):
        with_header = movielens.load_spark_df(spark, header=header)
    assert with_header.count() == 100000
    assert len(with_header.columns) == 4

    # Schemas with a wrong field type must raise.  Compared with the accepted
    # schema below (Integer, Integer, Float), the last field of each of these
    # schemas carries the offending type.
    invalid_schemas = (
        StructType([StructField("u", StringType())]),
        StructType(
            [
                StructField("u", IntegerType()),
                StructField("i", StringType()),
            ]
        ),
        StructType(
            [
                StructField("u", IntegerType()),
                StructField("i", IntegerType()),
                StructField("r", IntegerType()),
            ]
        ),
    )
    for schema in invalid_schemas:
        with pytest.raises(ValueError):
            movielens.load_spark_df(spark, schema=schema)

    # A schema shorter than four fields yields that many columns
    schema = StructType(
        [
            StructField("u", IntegerType()),
            StructField("i", IntegerType()),
            StructField("r", FloatType()),
        ]
    )
    with_schema = movielens.load_spark_df(spark, schema=schema)
    assert with_schema.count() == 100000
    assert len(with_schema.columns) == len(schema)

    # A schema longer than four fields is expected to warn
    schema = StructType(
        [
            StructField("u", IntegerType()),
            StructField("i", IntegerType()),
            StructField("r", DoubleType()),
            StructField("a", IntegerType()),
            StructField("b", IntegerType()),
        ]
    )
    with pytest.warns(Warning):
        with_schema = movielens.load_spark_df(spark, schema=schema)
    assert with_schema.count() == 100000
    assert len(with_schema.columns) == 4

    # The schema is used when both schema and header are provided.  ``header``
    # is still the five-name list assigned above.
    schema = StructType([StructField("u", IntegerType())])
    with pytest.warns(Warning):
        with_schema = movielens.load_spark_df(
            spark, header=header, schema=schema
        )
    assert with_schema.count() == 100000
    assert len(with_schema.columns) == len(schema)