예제 #1
0
파일: LR.py 프로젝트: angeliababy/text_LTR
    from pyspark.ml.linalg import Vectors

    return Vectors.dense(row[0]), Vectors.dense(row[1]), row[2]


# Column names for the tuples produced by list_to_vector:
# (article feature vector, user feature vector, click label).
columns = ['article_feature', 'user_feature', 'clicked']
train = train.map(list_to_vector).toDF(columns)

# 3. Feature column and label column of the training data.
from pyspark.ml.feature import VectorAssembler

# Concatenate the article-feature and user-feature vectors into one 'features' column.
feature_assembler = VectorAssembler(inputCols=columns[0:2], outputCol='features')
train = feature_assembler.transform(train)
print("训练集", train.take(10))
train.coalesce(2).write.mode('overwrite').parquet('datas/train/')

# Part 2: model training.
from pyspark.ml.classification import LogisticRegression

# 'clicked' is the label column; 'features' is the assembled vector column.
lr = LogisticRegression(labelCol="clicked", featuresCol="features")
model = lr.fit(train)
model.write().overwrite().save("models/lr.obj")

# Part 3: prediction. NOTE: this reloads the saved model and scores the
# training set itself, not held-out data.
from pyspark.ml.classification import LogisticRegressionModel

online_model = LogisticRegressionModel.load("models/lr.obj")
sort_res = online_model.transform(train)
print("在线预测结果", sort_res.take(10))
예제 #2
0
article_profile = article_profile.map(article_profile_to_feature).toDF(['article_id', 'channel_id', 'weights'])
# print("文章权重", article_profile.take(10))

# Part 2: load the article vectors.
# Forward slashes resolve on both POSIX and Windows; on POSIX, Hadoop treats
# backslashes as literal filename characters, so a backslash path is not found.
article_vector = spark.sparkContext.textFile('data/backup/article.db/article_vector')
# Each line holds '\x01'-separated fields: article_id, channel_id, articlevector
# (presumably a Hive text-format table export -- '\x01' is Hive's default field
# delimiter; confirm against the data source).
article_vector = article_vector.map(lambda line: line.split('\x01')).toDF(['article_id', 'channel_id', 'articlevector'])
# print("文章向量", article_vector.take(10))

# Joining on article_id only keeps BOTH channel_id columns, so the joined rows are
# (article_id, channel_id, weights, channel_id, articlevector). feature_to_vector
# reads these rows by position and relies on this exact layout.
article_feature = article_profile.join(article_vector, on=['article_id'], how='inner')

def feature_to_vector(row):
    r"""Convert one joined article row into plain values and dense ML vectors.

    Fields are read by position:
      0 -> article_id (passed through unchanged)
      1 -> channel_id (converted with int())
      2 -> weights (wrapped in a DenseVector)
      4 -> articlevector, a string whose values are separated by '\x02'
           (split and wrapped in a DenseVector)
    Index 3 is not read. On the article_profile/article_vector join, it is
    the duplicate channel_id column contributed by article_vector.

    Returns:
        (article_id, channel_id, weights_vector, article_vector)
    """
    from pyspark.ml.linalg import Vectors

    article_id = row[0]
    channel_id = int(row[1])
    weights = Vectors.dense(row[2])
    embedding = Vectors.dense(row[4].split('\x02'))
    return article_id, channel_id, weights, embedding


article_feature = article_feature.rdd.map(feature_to_vector).toDF(['article_id', 'channel_id', 'weights', 'articlevector'])

# Part 3: merge all article features into a single vector column.
from pyspark.ml.feature import VectorAssembler

columns = ['article_id', 'channel_id', 'weights', 'articlevector']
# Assemble channel_id, weights and articlevector (in that order) into
# 'article_features', then drop the source columns, leaving
# (article_id, article_features).
article_feature = (
    VectorAssembler()
    .setInputCols(columns[1:4])
    .setOutputCol("article_features")
    .transform(article_feature)
    .drop('channel_id', 'weights', 'articlevector')
)
print("文章特征", article_feature.take(10))
# Sample output of the print above (truncated):
# [Row(article_id='104454', article_features=DenseVector([15.0, 3.2094, 3.2039, ..., -0.0809, -0.1995])),
#  Row(article_id='11078', article_features=DenseVector([19.0, 4.3524, 2.6313, ..., 0.2256, 0.0315])),
#  ...]
# In the first two rows shown, each vector has 111 entries. Because channel_id
# is the first assembler input, each vector starts with the channel_id value.

article_feature.coalesce(2).write.parquet(path='datas/article_features/', mode='overwrite')