Example #1
    def test_gen_cats_from_items(self):
        spark = OrcaContext.get_spark_session()
        sc = OrcaContext.get_spark_context()
        data = [
            ("jack", [1, 2, 3, 4, 5]),
            ("alice", [4, 5, 6, 7, 8]),
            ("rose", [1, 2])]
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("item_hist_seq", ArrayType(IntegerType()), True)])

        df = spark.createDataFrame(data, schema)
        df.filter("name like '%alice%'").show()

        df2 = sc \
            .parallelize([(0, 0), (1, 0), (2, 0), (3, 0), (4, 1), (5, 1), (6, 1), (8, 2), (9, 2)]) \
            .toDF(["item", "category"]).withColumn("item", col("item").cast("Integer")) \
            .withColumn("category", col("category").cast("Integer"))
        tbl = FeatureTable(df)
        tbl2 = tbl.add_neg_hist_seq(9, "item_hist_seq", 4)
        tbl3 = tbl2.add_feature(["item_hist_seq", "neg_item_hist_seq"], FeatureTable(df2), 5)
        assert tbl3.df.select("category_hist_seq").count() == 3
        assert tbl3.df.select("neg_category_hist_seq").count() == 3
        assert tbl3.df.filter("name like '%alice%'").select("neg_category_hist_seq").count() == 1
        assert tbl3.df.filter("name == 'rose'").select("neg_category_hist_seq").count() == 1
Example #2
 def test_add_hist_seq(self):
     spark = OrcaContext.get_spark_session()
     data = [("jack", 1, "2019-07-01 12:01:19.000"),
             ("jack", 2, "2019-08-01 12:01:19.000"),
             ("jack", 3, "2019-09-01 12:01:19.000"),
             ("jack", 4, "2019-07-02 12:01:19.000"),
             ("jack", 5, "2019-08-03 12:01:19.000"),
             ("jack", 6, "2019-07-04 12:01:19.000"),
             ("jack", 7, "2019-08-05 12:01:19.000"),
             ("alice", 4, "2019-09-01 12:01:19.000"),
             ("alice", 5, "2019-10-01 12:01:19.000"),
             ("alice", 6, "2019-11-01 12:01:19.000")]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("item", IntegerType(), True),
         StructField("time", StringType(), True)
     ])
     df = spark.createDataFrame(data=data, schema=schema)
     df = df.withColumn("ts", col("time").cast("timestamp").cast("long"))
     tbl = FeatureTable(df.select("name", "item", "ts")) \
         .add_hist_seq("name", ["item"], "ts", 1, 4)
     assert tbl.size() == 8
     assert tbl.df.filter(col("name") == "alice").count() == 2
     assert tbl.df.filter("name like '%jack'").count() == 6
     assert "item_hist_seq" in tbl.df.columns
Example #3
 def test_cast(self):
     spark = OrcaContext.get_spark_session()
     data = [("jack", "123", 14, 8),
             ("alice", "34", 25, 9),
             ("rose", "25344", 23, 10)]
     schema = StructType([StructField("name", StringType(), True),
                          StructField("a", StringType(), True),
                          StructField("b", IntegerType(), True),
                          StructField("c", IntegerType(), True)])
     df = spark.createDataFrame(data, schema)
     tbl = FeatureTable(df)
     tbl = tbl.cast("a", "int")
     assert dict(tbl.df.dtypes)['a'] == "int", "column a should now be cast to integer type"
     tbl = tbl.cast("a", "float")
     assert dict(tbl.df.dtypes)['a'] == "float", "column a should now be cast to float type"
     tbl = tbl.cast(["b", "c"], "double")
     assert dict(tbl.df.dtypes)['b'] == dict(tbl.df.dtypes)['c'] == "double", \
         "columns b and c should now be cast to double type"
     tbl = tbl.cast(None, "float")
     assert dict(tbl.df.dtypes)['name'] == dict(tbl.df.dtypes)['a'] == dict(tbl.df.dtypes)['b'] \
         == dict(tbl.df.dtypes)['c'] == "float", \
         "all the columns should now be cast to float type"
     with self.assertRaises(Exception) as context:
         tbl = tbl.cast("a", "notvalid")
     self.assertTrue(
         "type should be string, boolean, int, long, short, float, double."
         in str(context.exception))
Example #4
 def test_to_list(self):
     spark = OrcaContext.get_spark_session()
     data = [("jack", "123", 14, 8.5, [0, 0]),
             ("alice", "34", 25, 9.6, [1, 1]),
             ("rose", "25344", 23, 10.0, [2, 2])]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("num", StringType(), True),
         StructField("age", IntegerType(), True),
         StructField("height", DoubleType(), True),
         StructField("array", ArrayType(IntegerType()), True)
     ])
     tbl = FeatureTable(spark.createDataFrame(data, schema))
     list1 = tbl.to_list("name")
     list2 = tbl.to_list("num")
     list3 = tbl.to_list("age")
     list4 = tbl.to_list("height")
     list5 = tbl.to_list("array")
     assert list1 == ["jack", "alice",
                      "rose"], "the result of name is not correct"
     assert list2 == ["123", "34",
                      "25344"], "the result of num is not correct"
     assert list3 == [14, 25, 23], "the result of age is not correct"
     assert list4 == [8.5, 9.6, 10.0], "the result of height is not correct"
     assert list5 == [[0, 0], [1, 1],
                      [2, 2]], "the result of array is not correct"
Example #5
 def test_cross_hash_encode(self):
     spark = OrcaContext.get_spark_session()
     data = [("a", "b", "c", 1), ("b", "a", "d", 2), ("a", "c", "e", 3),
             ("c", "c", "c", 2), ("b", "a", "d", 1), ("a", "d", "e", 1)]
     schema = StructType([
         StructField("A", StringType(), True),
         StructField("B", StringType(), True),
         StructField("C", StringType(), True),
         StructField("D", IntegerType(), True)
     ])
     df = spark.createDataFrame(data, schema)
     cross_hash_df = df.withColumn("A_B_C", concat("A", "B", "C"))
     tbl = FeatureTable(df)
     cross_hash_str = lambda x: hashlib.md5(
         str(x).encode('utf-8', 'strict')).hexdigest()
     cross_hash_int = lambda x: int(cross_hash_str(x), 16) % 100
     cross_hash_value = []
     for row in cross_hash_df.collect():
         cross_hash_value.append(cross_hash_int(row[4]))
     tbl_cross_hash = []
     for record in tbl.cross_hash_encode(["A", "B", "C"],
                                         100).to_spark_df().collect():
         tbl_cross_hash.append(int(record[4]))
     assert operator.eq(cross_hash_value, tbl_cross_hash), \
         "the cross hash encoded values should be equal"
Example #6
 def test_get_stats(self):
     spark = OrcaContext.get_spark_session()
     data = [("jack", "123", 14, 8.5), ("alice", "34", 25, 9.7),
             ("rose", "25344", 23, 10.0)]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("num", StringType(), True),
         StructField("age", IntegerType(), True),
         StructField("height", DoubleType(), True)
     ])
     tbl = FeatureTable(spark.createDataFrame(data, schema))
     columns = ["age", "height"]
     # test str
     statistics = tbl.get_stats(columns, "min")
     assert len(statistics) == 2, "the dict should contain two statistics"
     assert statistics["age"] == 14, "the min value of age is not correct"
     assert statistics[
         "height"] == 8.5, "the min value of height is not correct"
     columns = ["age", "height"]
     # test dict
     statistics = tbl.get_stats(columns, {"age": "max", "height": "avg"})
     assert len(statistics) == 2, "the dict should contain two statistics"
     assert statistics["age"] == 25, "the max value of age is not correct"
     assert statistics[
         "height"] == 9.4, "the avg value of height is not correct"
     # test list
     statistics = tbl.get_stats(columns, ["min", "max"])
     assert len(statistics) == 2, "the dict should contain two statistics"
     assert statistics["age"][
         0] == 14, "the min value of age is not correct"
     assert statistics["age"][
         1] == 25, "the max value of age is not correct"
     assert statistics["height"][
         0] == 8.5, "the min value of height is not correct"
     assert statistics["height"][
         1] == 10.0, "the max value of height is not correct"
     # test dict of list
     statistics = tbl.get_stats(columns, {
         "age": ["min", "max"],
         "height": ["min", "avg"]
     })
     assert len(statistics) == 2, "the dict should contain two statistics"
     assert statistics["age"][
         0] == 14, "the min value of age is not correct"
     assert statistics["age"][
         1] == 25, "the max value of age is not correct"
     assert statistics["height"][
         0] == 8.5, "the min value of height is not correct"
     assert statistics["height"][
         1] == 9.4, "the max value of height is not correct"
     statistics = tbl.get_stats(None, "min")
     assert len(statistics) == 2, "the dict should contain two statistics"
     assert statistics["age"] == 14, "the min value of age is not correct"
     assert statistics[
         "height"] == 8.5, "the min value of height is not correct"
Example #7
 def test_sample(self):
     spark = OrcaContext.get_spark_session()
     df = spark.range(1000)
     feature_tbl = FeatureTable(df)
     total_line_1 = feature_tbl.size()
     feature_tbl2 = feature_tbl.sample(0.5)
     total_line_2 = feature_tbl2.size()
     assert int(total_line_1/2) - 100 < total_line_2 < int(total_line_1/2) + 100, \
         "the number of rows should be half"
     total_distinct_line = feature_tbl2.distinct().size()
     assert total_line_2 == total_distinct_line, "all rows should be distinct"
Example #8
 def test_filter_by_frequency(self):
     data = [("a", "b", 1), ("b", "a", 2), ("a", "bc", 3), ("c", "c", 2),
             ("b", "a", 2), ("ab", "c", 1), ("c", "b", 1), ("a", "b", 1)]
     schema = StructType([
         StructField("A", StringType(), True),
         StructField("B", StringType(), True),
         StructField("C", IntegerType(), True)
     ])
     spark = OrcaContext.get_spark_session()
     df = spark.createDataFrame(data, schema)
     tbl = FeatureTable(df).filter_by_frequency(["A", "B", "C"])
     assert tbl.to_spark_df().count() == 2, \
         "the count of frequency >= 2 should be 2"
Example #9
 def _read_json(paths, cols):
     if not isinstance(paths, list):
         paths = [paths]
     spark = OrcaContext.get_spark_session()
     df = spark.read.json(paths)
     if cols:
         if isinstance(cols, list):
             df = df.select(*cols)
         elif isinstance(cols, str):
             df = df.select(cols)
         else:
             raise Exception("cols should be a column name or list of column names")
     return df
Example #10
 def test_pad(self):
     spark = OrcaContext.get_spark_session()
     data = [
         ("jack", [1, 2, 3, 4, 5], [[1, 2, 3], [1, 2, 3]]),
         ("alice", [4, 5, 6, 7, 8], [[1, 2, 3], [1, 2, 3]]),
         ("rose", [1, 2], [[1, 2, 3]])]
     schema = StructType([StructField("name", StringType(), True),
                          StructField("list", ArrayType(IntegerType()), True),
                          StructField("matrix", ArrayType(ArrayType(IntegerType())))])
     df = spark.createDataFrame(data, schema)
     tbl = FeatureTable(df).pad(["list", "matrix"], 4)
     dft = tbl.df
     assert dft.filter("size(matrix) = 4").count() == 3
     assert dft.filter("size(list) = 4").count() == 3
Example #11
    def test_mask(self):
        spark = OrcaContext.get_spark_session()
        data = [("jack", [1, 2, 3, 4, 5]), ("alice", [4, 5, 6, 7, 8]),
                ("rose", [1, 2])]
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("history", ArrayType(IntegerType()), True)
        ])

        df = spark.createDataFrame(data, schema)
        tbl = FeatureTable(df).mask(["history"], 4)
        assert "history_mask" in tbl.df.columns
        assert tbl.df.filter("size(history_mask) = 4").count() == 3
        assert tbl.df.filter("size(history_mask) = 2").count() == 0
Example #12
 def test_to_dict(self):
     spark = OrcaContext.get_spark_session()
     # test the case the column of key is unique
     data = [("jack", "123", 14), ("alice", "34", 25),
             ("rose", "25344", 23)]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("num", StringType(), True),
         StructField("age", IntegerType(), True)
     ])
     tbl = FeatureTable(spark.createDataFrame(data, schema))
     dictionary = tbl.to_dict()
     print(dictionary)
     assert dictionary["name"] == ['jack', 'alice', 'rose']
Example #13
 def test_ordinal_shuffle(self):
     spark = OrcaContext.get_spark_session()
     data = [("a", 14), ("b", 25), ("c", 23), ("d", 2), ("e", 1)]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("num", IntegerType(), True)
     ])
     tbl = FeatureTable(spark.createDataFrame(data, schema).repartition(1))
     shuffled_tbl = tbl.ordinal_shuffle_partition()
     rows = tbl.df.collect()
     shuffled_rows = shuffled_tbl.df.collect()
     rows.sort(key=lambda x: x[1])
     shuffled_rows.sort(key=lambda x: x[1])
     assert rows == shuffled_rows
Example #14
    def test_add_length(self):
        spark = OrcaContext.get_spark_session()
        data = [("jack", [1, 2, 3, 4, 5]),
                ("alice", [4, 5, 6, 7, 8]),
                ("rose", [1, 2])]
        schema = StructType([StructField("name", StringType(), True),
                             StructField("history", ArrayType(IntegerType()), True)])

        df = spark.createDataFrame(data, schema)
        tbl = FeatureTable(df)
        tbl = tbl.add_length("history")
        assert "history_length" in tbl.df.columns
        assert tbl.df.filter("history_length = 5").count() == 2
        assert tbl.df.filter("history_length = 2").count() == 1
Example #15
 def test_max(self):
     spark = OrcaContext.get_spark_session()
     data = [("jack", "123", 14, 8.5), ("alice", "34", 25, 9.7),
             ("rose", "25344", 23, 10.0)]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("num", StringType(), True),
         StructField("age", IntegerType(), True),
         StructField("height", DoubleType(), True)
     ])
     tbl = FeatureTable(spark.createDataFrame(data, schema))
     columns = ["age", "height"]
     max_result = tbl.max(columns)
     assert max_result.to_list("max") == [25, 10.0], \
         "the maximum value for age and height is not correct"
Example #16
 def test_add_negative_items(self):
     spark = OrcaContext.get_spark_session()
     data = [("jack", 1, "2019-07-01 12:01:19.000"),
             ("jack", 2, "2019-08-01 12:01:19.000"),
             ("jack", 3, "2019-09-01 12:01:19.000"),
             ("alice", 4, "2019-09-01 12:01:19.000"),
             ("alice", 5, "2019-10-01 12:01:19.000"),
             ("alice", 6, "2019-11-01 12:01:19.000")]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("item", IntegerType(), True),
         StructField("time", StringType(), True)
     ])
     df = spark.createDataFrame(data=data, schema=schema)
     tbl = FeatureTable(df).add_negative_samples(10)
     dft = tbl.df
     assert tbl.size() == 12
     assert dft.filter("label == 1").count() == 6
     assert dft.filter("label == 0").count() == 6
Example #17
 def test_pad(self):
     spark = OrcaContext.get_spark_session()
     data = [("jack", [1, 2, 3, 4, 5], [[1, 2, 3], [1, 2, 3]]),
             ("alice", [4, 5, 6, 7, 8], [[1, 2, 3], [1, 2, 3]]),
             ("rose", [1, 2], [[1, 2, 3]])]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("list", ArrayType(IntegerType()), True),
         StructField("matrix", ArrayType(ArrayType(IntegerType())))
     ])
     df = spark.createDataFrame(data, schema)
     tbl1 = FeatureTable(df).pad(["list", "matrix"], seq_len=4)
     dft1 = tbl1.df
     tbl2 = FeatureTable(df).pad(cols=["list", "matrix"],
                                 mask_cols=["list"],
                                 seq_len=4)
     assert dft1.filter("size(matrix) = 4").count() == 3
     assert dft1.filter("size(list) = 4").count() == 3
     assert tbl2.df.filter("size(list_mask) = 4").count() == 3
     assert tbl2.df.filter("size(list_mask) = 2").count() == 0
     assert "list_mask" in tbl2.df.columns
Example #18
 def test_encode_string_from_dict(self):
     spark = OrcaContext.get_spark_session()
     data = [("jack", "123", 14, 8),
             ("alice", "34", 25, 9),
             ("rose", "25344", 23, 10)]
     schema = StructType([StructField("name", StringType(), True),
                          StructField("num", StringType(), True),
                          StructField("age", IntegerType(), True),
                          StructField("height", IntegerType(), True)])
     tbl = FeatureTable(spark.createDataFrame(data, schema))
     columns = ["name", "num"]
     indices = []
     indices.append({"jack": 1, "alice": 2, "rose": 3})
     indices.append({"123": 3, "34": 1, "25344": 2})
     tbl = tbl.encode_string(columns, indices)
     assert 'name' in tbl.df.columns, "name should be still in the columns"
     assert 'num' in tbl.df.columns, "num should be still in the columns"
     assert tbl.df.where(tbl.df.age == 14).select("name").collect()[0]["name"] == 1, \
         "the first row of name should be 1"
     assert tbl.df.where(tbl.df.height == 10).select("num").collect()[0]["num"] == 2, \
         "the third row of num should be 2"
Example #19
    def from_dict(cls, indices, col_name):
        """
        Create the StringIndex from a dict of indices.

        :param indices: dict. Each key is a value of the categorical column and
                        each value is the corresponding index.
                        We assume that the key is a str and the value is an int.
        :param col_name: str. The column name of the categorical column.

        :return: A StringIndex.
        """
        spark = OrcaContext.get_spark_session()
        if not isinstance(indices, dict):
            raise ValueError('indices should be dict, but got ' + indices.__class__.__name__)
        if not col_name:
            raise ValueError('col_name should be str, but got None')
        if not isinstance(col_name, str):
            raise ValueError('col_name should be str, but got ' + col_name.__class__.__name__)
        indices = map(lambda x: {col_name: x[0], 'id': x[1]}, indices.items())
        df = spark.createDataFrame(Row(**x) for x in indices)
        return cls(df, col_name)
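A minimal usage sketch for the classmethod above, assuming StringIndex is importable from zoo.friesian.feature (as in the script at the end of this page) and that an Orca context has already been initialized:

# Hedged usage sketch; assumes an initialized Orca context.
from zoo.orca import init_orca_context
from zoo.friesian.feature import StringIndex

sc = init_orca_context("local")
# Build a StringIndex for the "name" column from an explicit string-to-id mapping.
name_index = StringIndex.from_dict({"jack": 1, "alice": 2, "rose": 3}, "name")
name_index.df.show()  # the wrapped Spark DataFrame should contain the columns "name" and "id"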
Example #20
 def test_hash_encode(self):
     spark = OrcaContext.get_spark_session()
     data = [("a", "b", 1), ("b", "a", 2), ("a", "c", 3), ("c", "c", 2),
             ("b", "a", 1), ("a", "d", 1)]
     schema = StructType([
         StructField("A", StringType(), True),
         StructField("B", StringType(), True),
         StructField("C", IntegerType(), True)
     ])
     df = spark.createDataFrame(data, schema)
     tbl = FeatureTable(df)
     hash_str = lambda x: hashlib.md5(str(x).encode('utf-8', 'strict')
                                      ).hexdigest()
     hash_int = lambda x: int(hash_str(x), 16) % 100
     hash_value = []
     for row in df.collect():
         hash_value.append(hash_int(row[0]))
     tbl_hash = []
     for record in tbl.hash_encode(["A"], 100).to_spark_df().collect():
         tbl_hash.append(int(record[0]))
     assert (operator.eq(
         hash_value, tbl_hash)), "the hash encoded value should be equal"
Example #21
 def test_add(self):
     spark = OrcaContext.get_spark_session()
     data = [("jack", "123", 14, 8.5), ("alice", "34", 25, 9.6),
             ("rose", "25344", 23, 10.0)]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("num", StringType(), True),
         StructField("age", IntegerType(), True),
         StructField("height", DoubleType(), True)
     ])
     tbl = FeatureTable(spark.createDataFrame(data, schema))
     columns = ["age", "height"]
     new_tbl = tbl.add(columns, 1.5)
     new_list = new_tbl.df.take(3)
     assert len(new_list) == 3, "new_tbl should have 3 rows"
     assert new_list[0][
         'age'] == 15.5, "the age of jack should increase 1.5"
     assert new_list[0][
         'height'] == 10, "the height of jack should increase 1.5"
     assert new_list[1][
         'age'] == 26.5, "the age of alice should increase 1.5"
     assert new_list[1][
         'height'] == 11.1, "the height of alice should increase 1.5"
     assert new_list[2][
         'age'] == 24.5, "the age of rose should increase 1.5"
     assert new_list[2][
         'height'] == 11.5, "the height of rose should increase 1.5"
     new_tbl = tbl.add(columns, -1)
     new_list = new_tbl.df.take(3)
     assert len(new_list) == 3, "new_tbl should have 3 rows"
     assert new_list[0]['age'] == 13, "the age of jack should decrease 1"
     assert new_list[0][
         'height'] == 7.5, "the height of jack should decrease 1"
     assert new_list[1]['age'] == 24, "the age of alice should decrease 1"
     assert new_list[1][
         'height'] == 8.6, "the height of alice should decrease 1"
     assert new_list[2]['age'] == 22, "the age of rose should decrease 1"
     assert new_list[2][
         'height'] == 9.0, "the height of rose should decrease 1"
Example #22
    def test_string_input(self):
        def model_creator(config):
            import tensorflow as tf
            vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
                max_tokens=10, output_mode='int', output_sequence_length=4)
            model = tf.keras.models.Sequential()
            model.add(tf.keras.Input(shape=(1, ), dtype=tf.string))
            model.add(vectorize_layer)
            return model

        from zoo.orca import OrcaContext
        from pyspark.sql.types import StructType, StructField, StringType
        spark = OrcaContext.get_spark_session()
        schema = StructType([StructField("input", StringType(), True)])
        input_data = [["foo qux bar"], ["qux baz"]]
        input_df = spark.createDataFrame(input_data, schema)
        estimator = Estimator.from_keras(model_creator=model_creator)
        output_df = estimator.predict(input_df,
                                      batch_size=1,
                                      feature_cols=["input"])
        output = output_df.collect()
        print(output)
Example #23
def read_parquet(file_path, columns=None, schema=None, **options):
    """

    Read parquet files to SparkXShards of pandas DataFrames.

    :param file_path: Parquet file path, a list of multiple parquet file paths, or a directory
    containing parquet files. Local file system, HDFS, and AWS S3 are supported.
    :param columns: list of column names, default=None.
    If not None, only these columns will be read from the file.
    :param schema: pyspark.sql.types.StructType for the input schema or
    a DDL-formatted string (for example, "col0 INT, col1 DOUBLE").
    :param options: other options for reading parquet.
    :return: An instance of SparkXShards.
    """
    sc = init_nncontext()
    spark = OrcaContext.get_spark_session()
    # df = spark.read.parquet(file_path)
    df = spark.read.load(file_path, "parquet", schema=schema, **options)

    if columns:
        df = df.select(*columns)

    def to_pandas(columns):
        def f(iter):
            import pandas as pd
            data = list(iter)
            pd_df = pd.DataFrame(data, columns=columns)
            return [pd_df]

        return f

    pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
    try:
        data_shards = SparkXShards(pd_rdd)
    except Exception as e:
        print("An error occurred when reading parquet files")
        raise e
    return data_shards
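A minimal usage sketch of read_parquet, mirroring the test in Example #25 below; the parquet path here is hypothetical and an Orca context is assumed to be initialized:

# Hedged usage sketch; "/tmp/test_parquet" is a hypothetical path.
import zoo.orca.data.pandas
from zoo.orca import init_orca_context

sc = init_orca_context("local")
shards = zoo.orca.data.pandas.read_parquet("/tmp/test_parquet",
                                           columns=["ID", "sale_price"])
pdf = shards.collect()[0]  # each element of the SparkXShards is a pandas DataFrame
print(pdf.dtypes)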
Example #24
    def read_df(esConfig, esResource, schema=None):
        """
        Read the data from Elasticsearch into a Spark DataFrame.

        :param esConfig: Dictionary which represents the configuration for
               Elasticsearch (e.g. ip, port, etc.).
        :param esResource: resource file in Elasticsearch.
        :param schema: Optional. Defines the schema of the Spark DataFrame.
                If each column in ES is a single value, the schema does not need to be set.
        :return: Spark DataFrame. Each row represents a document in ES.
        """
        sc = init_nncontext()
        spark = OrcaContext.get_spark_session()

        reader = spark.read.format("org.elasticsearch.spark.sql")

        for key in esConfig:
            reader.option(key, esConfig[key])
        if schema:
            reader.schema(schema)

        df = reader.load(esResource)
        return df
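A brief, hypothetical usage sketch of the read_df helper above; the host, port, and resource name are placeholders, and the elasticsearch-hadoop (elasticsearch-spark) connector must be available on the Spark classpath:

# Hedged usage sketch; the ES settings and resource name below are placeholders.
es_config = {"es.nodes": "localhost", "es.port": "9200"}
df = read_df(es_config, "my_index/my_type")
df.printSchema()
df.show(5)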
Example #25
    def test_read_parquet(self):
        file_path = os.path.join(self.resource_path, "orca/data/csv")
        sc = init_nncontext()
        from pyspark.sql.functions import col
        spark = OrcaContext.get_spark_session()
        df = spark.read.csv(file_path, header=True)
        df = df.withColumn('sale_price', col('sale_price').cast('int'))
        temp = tempfile.mkdtemp()
        df.write.parquet(os.path.join(temp, "test_parquet"))
        data_shard2 = zoo.orca.data.pandas.read_parquet(
            os.path.join(temp, "test_parquet"))
        assert data_shard2.num_partitions() == 2, "number of shard should be 2"
        data = data_shard2.collect()
        df = data[0]
        assert "location" in df.columns

        data_shard2 = zoo.orca.data.pandas.read_parquet(
            os.path.join(temp, "test_parquet"), columns=['ID', 'sale_price'])
        data = data_shard2.collect()
        df = data[0]
        assert len(df.columns) == 2

        from pyspark.sql.types import StructType, StructField, IntegerType, StringType
        schema = StructType([
            StructField("ID", StringType(), True),
            StructField("sale_price", IntegerType(), True),
            StructField("location", StringType(), True)
        ])
        data_shard3 = zoo.orca.data.pandas.read_parquet(
            os.path.join(temp, "test_parquet"),
            columns=['ID', 'sale_price'],
            schema=schema)
        data = data_shard3.collect()
        df = data[0]
        assert str(df['sale_price'].dtype) == 'int64'

        shutil.rmtree(temp)
Example #26
 def _read_parquet(paths):
     if not isinstance(paths, list):
         paths = [paths]
     spark = OrcaContext.get_spark_session()
     df = spark.read.parquet(*paths)
     return df
Example #27
def read_file_spark(file_path, file_type, **kwargs):
    sc = init_nncontext()
    node_num, core_num = get_node_and_core_number()
    backend = OrcaContext.pandas_read_backend

    if backend == "pandas":
        file_url_splits = file_path.split("://")
        prefix = file_url_splits[0]

        file_paths = []
        if isinstance(file_path, list):
            [file_paths.extend(extract_one_path(path, os.environ)) for path in file_path]
        else:
            file_paths = extract_one_path(file_path, os.environ)

        if not file_paths:
            raise Exception("The file path is invalid or empty, please check your data")

        num_files = len(file_paths)
        total_cores = node_num * core_num
        num_partitions = num_files if num_files < total_cores else total_cores
        rdd = sc.parallelize(file_paths, num_partitions)

        if prefix == "hdfs":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
        elif prefix == "s3":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
        else:
            def loadFile(iterator):
                dfs = []
                for x in iterator:
                    df = read_pd_file(x, file_type, **kwargs)
                    dfs.append(df)
                import pandas as pd
                return [pd.concat(dfs)]

            pd_rdd = rdd.mapPartitions(loadFile)
    else:  # Spark backend; spark.read.csv/json accepts a folder path as input
        assert file_type == "json" or file_type == "csv", \
            "Unsupported file type: %s. Only csv and json files are supported for now" % file_type
        spark = OrcaContext.get_spark_session()
        # TODO: add S3 credentials

        # The following implementation is adapted from
        # https://github.com/databricks/koalas/blob/master/databricks/koalas/namespace.py
        # with some modifications.

        if "mangle_dupe_cols" in kwargs:
            assert kwargs["mangle_dupe_cols"], "mangle_dupe_cols can only be True"
            kwargs.pop("mangle_dupe_cols")
        if "parse_dates" in kwargs:
            assert not kwargs["parse_dates"], "parse_dates can only be False"
            kwargs.pop("parse_dates")

        names = kwargs.get("names", None)
        if "names" in kwargs:
            kwargs.pop("names")
        usecols = kwargs.get("usecols", None)
        if "usecols" in kwargs:
            kwargs.pop("usecols")
        dtype = kwargs.get("dtype", None)
        if "dtype" in kwargs:
            kwargs.pop("dtype")
        squeeze = kwargs.get("squeeze", False)
        if "squeeze" in kwargs:
            kwargs.pop("squeeze")
        index_col = kwargs.get("index_col", None)
        if "index_col" in kwargs:
            kwargs.pop("index_col")

        if file_type == "csv":
            # Handle pandas-compatible keyword arguments
            kwargs["inferSchema"] = True
            header = kwargs.get("header", "infer")
            if isinstance(names, str):
                kwargs["schema"] = names
            if header == "infer":
                header = 0 if names is None else None
            if header == 0:
                kwargs["header"] = True
            elif header is None:
                kwargs["header"] = False
            else:
                raise ValueError("Unknown header argument {}".format(header))
            if "quotechar" in kwargs:
                quotechar = kwargs["quotechar"]
                kwargs.pop("quotechar")
                kwargs["quote"] = quotechar
            if "escapechar" in kwargs:
                escapechar = kwargs["escapechar"]
                kwargs.pop("escapechar")
                kwargs["escape"] = escapechar
            # sep and comment are the same as pandas
            if "comment" in kwargs:
                comment = kwargs["comment"]
                if not isinstance(comment, str) or len(comment) != 1:
                    raise ValueError("Only length-1 comment characters supported")
            df = spark.read.csv(file_path, **kwargs)
            if header is None:
                df = df.selectExpr(
                    *["`%s` as `%s`" % (field.name, i) for i, field in enumerate(df.schema)])
        else:
            df = spark.read.json(file_path, **kwargs)

        # Handle pandas-compatible postprocessing arguments
        if usecols is not None and not callable(usecols):
            usecols = list(usecols)
        renamed = False
        if isinstance(names, list):
            if len(set(names)) != len(names):
                raise ValueError("Found duplicate names, please check your names input")
            if usecols is not None:
                if not callable(usecols):
                    # usecols is list
                    if len(names) != len(usecols) and len(names) != len(df.schema):
                        raise ValueError(
                            "Passed names did not match usecols"
                        )
                if len(names) == len(df.schema):
                    df = df.selectExpr(
                        *["`%s` as `%s`" % (field.name, name) for field, name
                          in zip(df.schema, names)]
                    )
                    renamed = True

            else:
                if len(names) != len(df.schema):
                    raise ValueError(
                        "The number of names [%s] does not match the number "
                        "of columns [%d]. Try names by a Spark SQL DDL-formatted "
                        "string." % (len(names), len(df.schema))
                    )
                df = df.selectExpr(
                    *["`%s` as `%s`" % (field.name, name) for field, name
                      in zip(df.schema, names)]
                )
                renamed = True
        index_map = dict([(i, field.name) for i, field in enumerate(df.schema)])
        if usecols is not None:
            if callable(usecols):
                cols = [field.name for field in df.schema if usecols(field.name)]
                missing = []
            elif all(isinstance(col, int) for col in usecols):
                cols = [field.name for i, field in enumerate(df.schema) if i in usecols]
                missing = [
                    col
                    for col in usecols
                    if col >= len(df.schema) or df.schema[col].name not in cols
                ]
            elif all(isinstance(col, str) for col in usecols):
                cols = [field.name for field in df.schema if field.name in usecols]
                if isinstance(names, list):
                    missing = [c for c in usecols if c not in names]
                else:
                    missing = [col for col in usecols if col not in cols]
            else:
                raise ValueError(
                    "usecols must only be list-like of all strings, "
                    "all unicode, all integers or a callable.")
            if len(missing) > 0:
                raise ValueError(
                    "usecols do not match columns, columns expected but not found: %s" % missing)
            if len(cols) > 0:
                df = df.select(cols)
                if isinstance(names, list):
                    if not renamed:
                        df = df.selectExpr(
                            *["`%s` as `%s`" % (col, name) for col, name in zip(cols, names)]
                        )
                        # update index map after rename
                        for index, col in index_map.items():
                            if col in cols:
                                index_map[index] = names[cols.index(col)]

        if df.rdd.getNumPartitions() < node_num:
            df = df.repartition(node_num)

        def to_pandas(columns, squeeze=False, index_col=None):
            def f(iter):
                import pandas as pd
                data = list(iter)
                pd_df = pd.DataFrame(data, columns=columns)
                if dtype is not None:
                    if isinstance(dtype, dict):
                        for col, type in dtype.items():
                            if isinstance(col, str):
                                if col not in pd_df.columns:
                                    raise ValueError("column to be set type is not"
                                                     " in current dataframe")
                                pd_df[col] = pd_df[col].astype(type)
                            elif isinstance(col, int):
                                if index_map[col] not in pd_df.columns:
                                    raise ValueError("column index to be set type is not"
                                                     " in current dataframe")
                                pd_df[index_map[col]] = pd_df[index_map[col]].astype(type)
                    else:
                        pd_df = pd_df.astype(dtype)
                if squeeze and len(pd_df.columns) == 1:
                    pd_df = pd_df.iloc[:, 0]
                if index_col:
                    pd_df = pd_df.set_index(index_col)

                return [pd_df]

            return f

        pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns, squeeze, index_col))

    try:
        data_shards = SparkXShards(pd_rdd)
    except Exception as e:
        alternative_backend = "pandas" if backend == "spark" else "spark"
        print("An error occurred when reading files with '%s' backend, you may switch to '%s' "
              "backend for another try. You can set the backend using "
              "OrcaContext.pandas_read_backend" % (backend, alternative_backend))
        raise e
    return data_shards
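Since read_file_spark is driven by OrcaContext.pandas_read_backend, here is a short, hedged usage sketch, assuming it backs the zoo.orca.data.pandas.read_csv entry point and using a hypothetical CSV path:

# Hedged usage sketch; "/tmp/data.csv" is a hypothetical path.
import zoo.orca.data.pandas
from zoo.orca import init_orca_context, OrcaContext

sc = init_orca_context("local")
OrcaContext.pandas_read_backend = "spark"  # or "pandas"; selects which branch above is taken
shards = zoo.orca.data.pandas.read_csv("/tmp/data.csv")
print(shards.num_partitions())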
import sys
import time

from optparse import OptionParser
from pyspark import StorageLevel
from zoo.orca import init_orca_context, stop_orca_context, OrcaContext
from pyspark.sql.functions import udf, col
from zoo.friesian.feature import FeatureTable, StringIndex
from pyspark.sql.types import StringType, IntegerType, ArrayType, FloatType

if __name__ == "__main__":
    parser = OptionParser()
    parser.add_option("--meta", dest="meta_file")
    parser.add_option("--review", dest="review_file")
    parser.add_option("--output", dest="output")
    (options, args) = parser.parse_args(sys.argv)
    begin = time.time()
    sc = init_orca_context("local")
    spark = OrcaContext.get_spark_session()

    # read review data
    transaction_df = spark.read.json(options.review_file).select(
        ['reviewerID', 'asin', 'unixReviewTime']) \
        .withColumnRenamed('reviewerID', 'user') \
        .withColumnRenamed('asin', 'item') \
        .withColumnRenamed('unixReviewTime', 'time')\
        .dropna("any").persist(storageLevel=StorageLevel.DISK_ONLY)
    transaction_tbl = FeatureTable(transaction_df)
    print("review_tbl, ", transaction_tbl.size())

    # read meta data
    def get_category(x):
        cat = x[0][-1] if x[0][-1] is not None else "default"
        return cat.strip().lower()