Code Example #1
 def test_add_hist_seq(self):
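     # Test method from a FeatureTable unit-test class (hence self). It builds
     # per-user "item" history sequences ordered by the ts column; the assertions
     # expect one history row per interaction after each user's first. FeatureTable
     # and OrcaContext are assumed to be imported from the Friesian/Orca packages;
     # the StructType/StructField/StringType/IntegerType and col helpers come from
     # pyspark.sql.types and pyspark.sql.functions.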
     spark = OrcaContext.get_spark_session()
     data = [("jack", 1, "2019-07-01 12:01:19.000"),
             ("jack", 2, "2019-08-01 12:01:19.000"),
             ("jack", 3, "2019-09-01 12:01:19.000"),
             ("jack", 4, "2019-07-02 12:01:19.000"),
             ("jack", 5, "2019-08-03 12:01:19.000"),
             ("jack", 6, "2019-07-04 12:01:19.000"),
             ("jack", 7, "2019-08-05 12:01:19.000"),
             ("alice", 4, "2019-09-01 12:01:19.000"),
             ("alice", 5, "2019-10-01 12:01:19.000"),
             ("alice", 6, "2019-11-01 12:01:19.000")]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("item", IntegerType(), True),
         StructField("time", StringType(), True)
     ])
     df = spark.createDataFrame(data=data, schema=schema)
     df = df.withColumn("ts", col("time").cast("timestamp").cast("long"))
     tbl = FeatureTable(df.select("name", "item", "ts")) \
         .add_hist_seq("name", ["item"], "ts", 1, 4)
     assert tbl.size() == 8
     assert tbl.df.filter(col("name") == "alice").count() == 2
     assert tbl.df.filter("name like '%jack'").count() == 6
     assert "item_hist_seq" in tbl.df.columns
Code Example #2
 def test_sample(self):
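     # Samples roughly half of a 1000-row FeatureTable and checks that the sampled
     # size is within 100 rows of 500 and that all sampled rows are distinct.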
     spark = OrcaContext.get_spark_session()
     df = spark.range(1000)
     feature_tbl = FeatureTable(df)
     total_line_1 = feature_tbl.size()
     feature_tbl2 = feature_tbl.sample(0.5)
     total_line_2 = feature_tbl2.size()
     assert int(total_line_1/2) - 100 < total_line_2 < int(total_line_1/2) + 100, \
         "the number of rows should be half"
     total_distinct_line = feature_tbl2.distinct().size()
     assert total_line_2 == total_distinct_line, "all rows should be distinct"
Code Example #3
 def test_add_negative_items(self):
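     # Generates negative samples (label 0) for each observed row (label 1);
     # the assertions expect a 1:1 ratio, doubling the table from 6 to 12 rows.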
     spark = OrcaContext.get_spark_session()
     data = [("jack", 1, "2019-07-01 12:01:19.000"),
             ("jack", 2, "2019-08-01 12:01:19.000"),
             ("jack", 3, "2019-09-01 12:01:19.000"),
             ("alice", 4, "2019-09-01 12:01:19.000"),
             ("alice", 5, "2019-10-01 12:01:19.000"),
             ("alice", 6, "2019-11-01 12:01:19.000")]
     schema = StructType([
         StructField("name", StringType(), True),
         StructField("item", IntegerType(), True),
         StructField("time", StringType(), True)
     ])
     df = spark.createDataFrame(data=data, schema=schema)
     tbl = FeatureTable(df).add_negative_samples(10)
     dft = tbl.df
     assert tbl.size() == 12
     assert dft.filter("label == 1").count() == 6
     assert dft.filter("label == 0").count() == 6
Code Example #4
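    # Excerpt from a preprocessing script: earlier lines are assumed to import sys,
    # time, StorageLevel, StringType, init_orca_context, OrcaContext and FeatureTable,
    # and to create the OptionParser used below.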
    parser.add_option("--review", dest="review_file")
    parser.add_option("--output", dest="output")
    (options, args) = parser.parse_args(sys.argv)
    begin = time.time()
    sc = init_orca_context("local")
    spark = OrcaContext.get_spark_session()

    # read review data
    transaction_df = spark.read.json(options.review_file).select(
        ['reviewerID', 'asin', 'unixReviewTime']) \
        .withColumnRenamed('reviewerID', 'user') \
        .withColumnRenamed('asin', 'item') \
        .withColumnRenamed('unixReviewTime', 'time')\
        .dropna("any").persist(storageLevel=StorageLevel.DISK_ONLY)
    transaction_tbl = FeatureTable(transaction_df)
    print("review_tbl, ", transaction_tbl.size())

    # read meta data
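    # get_category picks the last entry of the first element of the "categories"
    # field (falling back to "default"), lowercases it, and is registered as a
    # Spark SQL UDF for use in selectExpr below.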
    def get_category(x):
        cat = x[0][-1] if x[0][-1] is not None else "default"
        return cat.strip().lower()
    spark.udf.register("get_category", get_category, StringType())
    item_df = spark.read.json(options.meta_file).select(['asin', 'categories'])\
        .dropna(subset=['asin', 'categories']) \
        .selectExpr("*", "get_category(categories) as category") \
        .withColumnRenamed("asin", "item").drop("categories").distinct()\
        .persist(storageLevel=StorageLevel.DISK_ONLY)
    item_tbl = FeatureTable(item_df)

    print("item_tbl, ", item_tbl.size())