예제 #1
0
def main(spark, train_data_file, rank_val, reg, alpha_val, user_indexer_model,
         item_indexer_model, model_file):
    '''
    Fit string indexers and an implicit-feedback ALS model on a parquet
    training set, then save the two fitted indexers and the ALS model.

    The training data is read with ``user_id``, ``track_id`` and ``count``
    columns in mind: the first two are indexed into ``user`` / ``item`` and
    ``count`` is used as the ALS rating column.

    Parameters
    ----------
    spark : SparkSession object
    train_data_file : string, path to the parquet file to load
    rank_val : value passed to ALS as ``rank``
    reg : value passed to ALS as ``regParam``
    alpha_val : value passed to ALS as ``alpha``
    user_indexer_model : string, path to store the fitted user indexer
    item_indexer_model : string, path to store the fitted item indexer
    model_file : string, path to store the serialized model file
    '''

    # Load the parquet file
    train = spark.read.parquet(train_data_file)

    # Fit the id -> index mappings on the training data. The fitted indexers
    # are saved below so the same mapping can be reused outside this function.
    indexer_user = StringIndexer(inputCol="user_id",
                                 outputCol="user",
                                 handleInvalid="skip").fit(train)
    indexer_item = StringIndexer(inputCol="track_id",
                                 outputCol="item",
                                 handleInvalid="skip").fit(train)
    als = ALS(userCol='user',
              itemCol='item',
              implicitPrefs=True,
              ratingCol='count',
              rank=rank_val,
              regParam=reg,
              alpha=alpha_val)

    # Apply the stages by hand (rather than through a Pipeline) so that each
    # fitted stage can be saved to its own path.
    train = indexer_user.transform(train)
    train = indexer_item.transform(train)
    model = als.fit(train)

    indexer_user.save(user_indexer_model)
    indexer_item.save(item_indexer_model)
    model.save(model_file)
예제 #2
0
def oneHot(df, base_col_name, col_name,
           temp_path='/data/20180621/ALL_58_beijing_save_models/'):
    """One-hot encode ``col_name`` and save the stages that were used.

    Rows whose ``base_col_name`` is null or whose ``col_name`` is the empty
    string are dropped. Null values and the literal string ``'NULL'`` in
    ``col_name`` are both mapped to the placeholder ``col_name + '_null'``
    before indexing.

    Parameters
    ----------
    df : Spark DataFrame containing at least the two named columns
    base_col_name : string, key column carried through to the result
    col_name : string, categorical column to encode
    temp_path : string, directory under which the StringIndexer, the fitted
        StringIndexer model and the OneHotEncoder are saved

    Returns
    -------
    DataFrame with two columns: ``base_col_name`` and ``col_name + 'Vec'``.

    Side effects: sets the SparkContext log level to WARN, prints the row
    count after filtering, and writes three artifacts under ``temp_path``.
    NOTE(review): per the Spark ML docs ``save()`` does not overwrite an
    existing path, so a second run against the same directory is expected to
    fail - confirm whether ``write().overwrite()`` is wanted.
    NOTE(review): the encoder is applied with ``transform`` without ``fit``,
    which matches the Spark 2.x OneHotEncoder API - confirm the Spark version.
    """
    import posixpath

    from pyspark import SparkContext, SparkConf

    # The configuration only matters if no SparkContext is active yet; the
    # call is kept so the WARN log level is still applied as before.
    spark_conf = SparkConf() \
        .setAppName('pyspark rentmodel') \
        .setMaster('local[*]')
    sc = SparkContext.getOrCreate(spark_conf)
    sc.setLogLevel('WARN')

    null_col_name = col_name + '_null'

    df = df.select(base_col_name, col_name)
    df = df.filter(df[base_col_name].isNotNull())
    # StringIndexer is used without handleInvalid='keep' here, so nulls are
    # replaced with a placeholder category before fitting.
    df = df.na.fill(null_col_name, col_name)
    df_null = df.filter(df[col_name] == 'NULL')

    df = df.filter(df[col_name].isNotNull())
    df = df.filter(df[col_name] != '')
    print('one-hot=======', col_name, df.count())

    if df_null.count() > 0:
        # Rewrite the literal 'NULL' rows to the same placeholder used for
        # real nulls. withColumn replaces the column in place, so the column
        # order stays (base_col_name, col_name) and the positional union
        # below lines up.
        to_placeholder = udf(lambda _value: null_col_name)
        df_null = df_null.withColumn(col_name, to_placeholder(col_name))
        df = df.filter(df[col_name] != 'NULL').union(df_null)

    index_name = col_name + 'Index'
    vector_name = col_name + 'Vec'
    # StringIndexer could be given handleInvalid='skip' (but not 'keep').
    # 'skip' silently drops the offending row, which is a poor experience:
    # a user submits one record and gets nothing back. It is therefore left
    # unset; when new data contains an unseen string, the caller can replace
    # it with a randomly chosen known string instead.
    stringIndexer = StringIndexer(inputCol=col_name, outputCol=index_name)
    model = stringIndexer.fit(df)
    indexed = model.transform(df)
    encoder = OneHotEncoder(dropLast=False,
                            inputCol=index_name,
                            outputCol=vector_name)
    encoded = encoder.transform(indexed)

    # Save the unfitted indexer, the fitted indexer model and the encoder.
    stringIndexer.save(posixpath.join(temp_path, 'stringIndexer' + col_name))
    model.save(posixpath.join(temp_path, 'stringIndexer_model' + col_name))
    encoder.save(posixpath.join(temp_path, col_name + '_new'))

    return encoded.select(base_col_name, vector_name)