예제 #1
0
 def test_gen_string_idx(self):
     """Generate string indices for two columns, round-trip them through
     parquet, and verify that log() shifts ids and that read_parquet
     validates the expected column name."""
     data_path = os.path.join(self.resource_path,
                              "friesian/feature/parquet/data1.parquet")
     tbl = FeatureTable.read_parquet(data_path)
     indices = tbl.gen_string_idx(["col_4", "col_5"], freq_limit="1")
     assert indices[0].count() == 3, "col_4 should have 3 indices"
     assert indices[1].count() == 2, "col_5 should have 2 indices"
     with tempfile.TemporaryDirectory() as tmp_dir:
         for idx in indices:
             idx.write_parquet(tmp_dir)
             logged = idx.log(["id"])
             # log() should move id 1 off of 1 in the transformed table.
             assert idx.df.filter("id == 1").count() == 1, \
                 "id in str_idx should = 1"
             assert logged.df.filter("id == 1").count() == 0, \
                 "id in str_idx_log should != 1"
         assert os.path.isdir(tmp_dir + "/col_4.parquet")
         assert os.path.isdir(tmp_dir + "/col_5.parquet")
         reloaded = StringIndex.read_parquet(tmp_dir + "/col_4.parquet")
         assert "col_4" in reloaded.df.columns, "col_4 should be a column of new_col_4_idx"
         # Reading col_5's parquet while claiming col_4 must raise.
         with self.assertRaises(Exception) as ctx:
             StringIndex.read_parquet(tmp_dir + "/col_5.parquet", "col_4")
         self.assertTrue('col_4 should be a column of the DataFrame'
                         in str(ctx.exception))
예제 #2
0
 def test_create_from_dict(self):
     """Build a StringIndex from a dict and check its schema and size,
     plus the validation errors raised for bad col_name / indices types."""
     mapping = {'a': 1, 'b': 2, 'c': 3}
     name = 'letter'
     index_tbl = StringIndex.from_dict(mapping, name)
     assert 'id' in index_tbl.df.columns, "id should be one column in the stringindex"
     assert 'letter' in index_tbl.df.columns, "letter should be one column in the stringindex"
     assert index_tbl.size() == 3, "the StringIndex should have three rows"
     # col_name must be a str; None and int are both rejected.
     for bad_name, message in [
             (None, "col_name should be str, but get None"),
             (12, "col_name should be str, but get int")]:
         with self.assertRaises(Exception) as ctx:
             StringIndex.from_dict(mapping, bad_name)
         self.assertTrue(message in str(ctx.exception))
     # indices must be a dict; a list is rejected.
     with self.assertRaises(Exception) as ctx:
         StringIndex.from_dict([mapping], name)
     self.assertTrue("indices should be dict, but get list"
                     in str(ctx.exception))
        # Normalize the raw category string (strip whitespace, lowercase).
        return cat.strip().lower()
    # Register the cleanup helper as a Spark SQL UDF so selectExpr can call it.
    spark.udf.register("get_category", get_category, StringType())
    # Load item metadata, derive a cleaned "category" column, rename
    # asin -> item, drop the raw column, and de-duplicate. Persisted to
    # disk only because the table is reused several times below.
    item_df = spark.read.json(options.meta_file).select(['asin', 'categories'])\
        .dropna(subset=['asin', 'categories']) \
        .selectExpr("*", "get_category(categories) as category") \
        .withColumnRenamed("asin", "item").drop("categories").distinct()\
        .persist(storageLevel=StorageLevel.DISK_ONLY)
    item_tbl = FeatureTable(item_df)

    print("item_tbl, ", item_tbl.size())

    # Build string indices for item and category (freq_limit=1 keeps all values).
    item_category_indices = item_tbl.gen_string_idx(["item", "category"], 1)
    # Ensure the category index contains a "default" entry: reuse the existing
    # id when present, otherwise append one with the next free id (== size).
    cat_default = item_category_indices[1].df.filter("category == 'default'").collect()
    # NOTE(review): assumes the id is the second field of the collected row —
    # confirm against the index schema.
    default_cat = cat_default[0][1] if cat_default else item_category_indices[1].size()
    new_row = spark.createDataFrame([("default", int(default_cat))], ["category", "id"])
    # distinct() makes the union a no-op when "default" was already indexed.
    category_index = StringIndex(item_category_indices[1].df.union(new_row).distinct()
                                 .withColumn("id", col("id").cast("Integer")), "category")
    item_size = item_category_indices[0].size()

    user_index = transaction_tbl.gen_string_idx(['user'], 1)
    # Map a 0/1 label to a two-element float vector [x, 1 - x].
    get_label = udf(lambda x: [float(x), 1 - float(x)], ArrayType(FloatType()))
    # Encode item/category to their integer ids; yields a distinct
    # item -> category mapping table.
    item2cat = item_tbl\
        .encode_string(["item", "category"], [item_category_indices[0], category_index])\
        .distinct()

    # Encode users/items, build per-user item history sequences sorted by time,
    # attach negative histories and negative samples, then join category info.
    full_tbl = transaction_tbl\
        .encode_string(['user', 'item'], [user_index[0], item_category_indices[0]])\
        .add_hist_seq(user_col="user", cols=['item'],
                      sort_col='time', min_len=1, max_len=100)\
        .add_neg_hist_seq(item_size, 'item_hist_seq', neg_num=5) \
        .add_negative_samples(item_size, item_col='item', neg_num=1)\
        .join(item2cat, "item")\