"""
微博情感分析,LDA主题聚类
"""
import os
import gc
import re
# import time
import logging
from argparse import ArgumentParser
from multiprocessing import Pool, Manager
from numpy import zeros, concatenate
from gensim.utils import grouper
from gensim.models import LdaModel, LdaMulticore
from gensim.corpora import Dictionary
from pandas import DataFrame, Series, read_excel, concat
from util import STOP_WORDS, output, generate_batch, to_excel, cache_path, dump_cache, load_cache
from workers import skep_consumer, skep_producer, ltp_tokenzier
# from ldamulticore import LdaMulticore
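
# Regexes for stripping repost markers ("//@user:", with ASCII or full-width colon) and URLs from Weibo text.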
FORWARD_SPLIT = re.compile(r"//@[^/::]+[::]")
FORWARD_CONTENT = re.compile(r"//@[^/::]+[::][^/]+")
URL_REGEX = re.compile(r"http[s]?://[a-zA-Z0-9.?/&=:]*")
_ARG_PARSER = ArgumentParser(description="My experiment; a configuration/dataset name must be specified")
_ARG_PARSER.add_argument('--name', '-n',
                         type=str,
                         default='clean',
                         help='dataset/configuration name, e.g. clean.')
_ARG_PARSER.add_argument('--ltpIDS', '-l',
                         type=str,
                         default='6,7',
                         help='gpu ids, like: 1,2,3')
_ARG_PARSER.add_argument('--skepIDS', '-s',
                         type=str,
                         default='1,2,3,4,5',
                         help='gpu ids, like: 1,2,3')
_ARG_PARSER.add_argument('--range', '-r',
                         type=str,
                         default='30,60',
                         help='search range for the number of topics, half-open interval [start, stop)')
_ARG_PARSER.add_argument('--passes', '-p',
                         type=int,
                         default=10,
                         help='number of passes over the dataset (epochs)')
_ARG_PARSER.add_argument('--iterations', '-it',
                         type=int,
                         default=50,
                         help='maximum number of iterations during inference')
_ARG_PARSER.add_argument('--keywords_num', '-k',
                         type=int,
                         default=50,
                         help='number of keywords to keep per topic')
_ARG_PARSER.add_argument('--pool_size', '-ps',
                         type=int,
                         default=14,
                         help='process pool size; the number of physical cores is recommended')
_ARG_PARSER.add_argument('--debug', '-d', type=bool, default=False)
_ARGS = _ARG_PARSER.parse_args()
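# Limit each process to a single OpenMP thread to avoid oversubscription across the worker pool.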
os.environ['OMP_NUM_THREADS'] = '1'
logging.basicConfig(format='[%(asctime)s - %(process)s - %(levelname)s] : %(message)s', level=logging.DEBUG)


def read(path) -> DataFrame:
    def _clean(row):
        text = URL_REGEX.sub('', row.contents)
        if row.is_forward and '//@' in text:
            # the post is a repost and carries the expected "//@user:" marker
            if text.startswith('//@'):
                # a pure repost: fall back to the content of the original (root) Weibo
                try:
                    text = FORWARD_CONTENT.findall(text)[-1]
                    i = FORWARD_SPLIT.match(text).regs[0][1]
                    text = text[i:]
                except IndexError:
                    text = text.replace('//@', '')  # TODO: could be resolved via the Weibo API
            else:
                # otherwise keep only the newly added comment before the repost marker
                text = text[:text.find('//@')]
        return text
    temp_name = os.path.basename(path).replace('.xlsx', '')
    if os.path.isfile(cache_path(temp_name)):
        data, texts = load_cache(temp_name)
    else:
        output(f"===> Reading from <{path}>.")
        data: DataFrame = read_excel(path)  # .iloc[:280]
        # keep only the four columns we need, drop missing values, and extract the date
        data = data[['contents', 'time', 'id', 'is_forward']].dropna().reset_index()
        data['date'] = data['time'].apply(lambda s: s[:10])
        data['contents'] = data['contents'].astype(str)
        # preprocess the text
        texts = data.apply(_clean, axis=1).to_list()
        dump_cache((data, texts), temp_name)
    output(f"===> got {len(data)} rows from <{path}>.")
    # parse the GPU ids
    ltp_ids = [i.strip() for i in _ARGS.ltpIDS.split(',')]
    skep_ids = [i.strip() for i in _ARGS.skepIDS.split(',')]
    # initialize the process pool, the manager and the data queues
    pool = Pool(1 + len(ltp_ids) + len(skep_ids))  # one SKEP-input producer, plus LTP tokenizers and SKEP scorers
    manager = Manager()
    feature_queue = manager.Queue(16 * len(skep_ids))
    result_queue = manager.Queue(16 * len(skep_ids))
    # launch the asynchronous tasks
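    # Pipeline (as inferred from how workers.py is used here): skep_producer batches the texts and
    # feeds model inputs into feature_queue; each skep_consumer pulls batches, scores sentiment on
    # its GPU and puts (slice, scores) onto result_queue; each ltp_tokenzier call tokenizes one
    # slice of the texts on its own GPU.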
    pool.apply_async(skep_producer, (feature_queue, texts, 16, len(skep_ids)))
    tokens = dict()
    for i, (s, p) in zip(ltp_ids, generate_batch(texts, len(texts) // len(ltp_ids) + 1)):
        tokens[(s.start, s.stop)] = pool.apply_async(ltp_tokenzier, (p, 192, i))
    for i in skep_ids:
        pool.apply_async(skep_consumer, (feature_queue, result_queue, i))
    # collect the results
    scores, counter = zeros(len(texts)), 1
    while True:
        _slice, array = result_queue.get()
        # print(_slice)
        if array is None:
            # every consumer sends a final None sentinel; stop once all of them have finished
            if counter < len(skep_ids):
                counter += 1
            else:
                break
        else:
            scores[_slice] = array
    data['tokens'] = None
    for s, t in tokens.items():
        data['tokens'].update(Series(t.get(), range(*s)))
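    # Each AsyncResult.get() above blocks until its LTP slice is tokenized; Series(values, range(*s))
    # writes the token lists back to the matching row positions.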
    data['sentiment_score'] = scores
    pool.close()
    pool.join()
    return data[['date', 'tokens', 'id', 'sentiment_score']]


def save_and_inference(model: LdaModel, corpus, num_topics, chunksize=0):
    path = f"./dev/model/{_ARGS.name}_{num_topics}.pkl"
    try:
        model.save(path)
        output(f"model saved at <{path}>")
        if chunksize > 0:
            gammas = [model.inference(chunk)[0] for chunk in grouper(corpus, chunksize)]
            gamma = concatenate(gammas)
        else:
            gamma, _ = model.inference(corpus)
    except RuntimeError as e:
        logging.error(f"PID: {os.getpid()}, num_topics: {num_topics} error")
        print(e)
        raise  # gamma would be undefined below, so re-raise instead of failing with a NameError
    output(f"num_topics {num_topics} inference complete.")
    return gamma.argmax(axis=1)


def eval_and_write(data, num_topics, documents, dictionary, corpus, model, ids):
    print('\nnum_topics: ', num_topics)
    # print('Model perplexity: ', model.log_perplexity(corpus))  # not normalized
    top_topics = model.top_topics(corpus, documents, dictionary,
                                  coherence='c_v', topn=_ARGS.keywords_num,
                                  processes=_ARGS.pool_size)
    scores = Series([t[1] for t in top_topics])
    print('Coherence Score: ', scores.mean())
    # collect each topic's keywords together with their collection frequencies
    topics_info = list()
    for _topic in top_topics:
        tokens = [(t[1], t[0], dictionary.cfs[dictionary.token2id[t[1]]]) for t in _topic[0]]
        topics_info.append((tokens, _topic[1]))
    to_excel(topics_info, data, ids, _ARGS.name)


def get_model(corpus, num_topics, kwargs):
    output(f"running num_topics: {num_topics}.")
    try:
        model = LdaModel(corpus, num_topics, **kwargs)
        topic_ids = save_and_inference(model, corpus, num_topics)
    except RuntimeError as e:
        logging.error(f"PID: {os.getpid()}, num_topics: {num_topics} error")
        print(e)
        raise  # model/topic_ids would be unbound below, so propagate the error to the caller
    return model, topic_ids


def pipeline(data: DataFrame):
    if os.path.isfile(cache_path('run/' + _ARGS.name)):
        corpus, dictionary, documents = load_cache('run/' + _ARGS.name)
    elif data is not None:  # a DataFrame's truth value is ambiguous, so test for None explicitly
        documents = data['tokens'].to_list()
        # Create a dictionary representation of the documents.
        dictionary = Dictionary(documents)
        # Filter out words that occur in fewer than 20 documents, or in more than 50% of the documents.
        dictionary.filter_extremes(no_below=20, no_above=0.5)
        # remove stop words
        bad_ids = [dictionary.token2id[t] for t in STOP_WORDS if t in dictionary.token2id]
        dictionary.filter_tokens(bad_ids=bad_ids)
        # Bag-of-words representation of the documents.
        corpus = [dictionary.doc2bow(doc) for doc in documents]
        dump_cache((corpus, dictionary, documents), 'run/' + _ARGS.name)
    else:
        raise ValueError('cache does not exist and no data was passed in')
    _ = dictionary[0]  # This is only to "load" the dictionary and build id2token.
    output('Number of unique tokens: ', len(dictionary))
    output('Number of documents: ', len(corpus))
    # test = get_model(6, corpus, dictionary.id2token)
    topic_range = tuple(int(s.strip()) for s in _ARGS.range.split(','))
    kwargs = dict(
        id2word=dictionary.id2token, chunksize=len(corpus),
        passes=_ARGS.passes, alpha='auto', eta='auto', eval_every=1,
        iterations=_ARGS.iterations, random_state=123)
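    # alpha='auto' and eta='auto' let gensim learn asymmetric document-topic and topic-word priors
    # from the data; chunksize=len(corpus) makes every pass a single full-batch update.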
    if len(corpus) < 1e6:  # train the candidate models in parallel
        pool = Pool(_ARGS.pool_size)
        result_dict = dict()
        for k in range(*topic_range):
            result_dict[k] = pool.apply_async(get_model, (corpus, k, kwargs))
        result_dict = {k: v.get() for k, v in result_dict.items()}
        pool.close()  # close the pool once the child processes have finished
        pool.join()
        output(f"Searched range {topic_range}")
        # the coherence computation spawns its own processes, so it has to run serially here
        for k, (model, ids) in result_dict.items():
            eval_and_write(data, k, documents, dictionary, corpus, model, ids)
    else:
        # kwargs['alpha'] = 'symmetric'
        kwargs['chunksize'] = len(corpus) // 8 // _ARGS.pool_size + 1
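        # With a chunksize far smaller than the corpus, gensim updates the model online
        # (mini-batch variational Bayes) instead of in one full-batch pass per epoch.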
        # kwargs['batch'] = True
        for k in range(*topic_range, 2):  # with a large dataset, search the range more coarsely
            # model = LdaMulticore(corpus, k, workers=_ARGS.pool_size, **kwargs)
            model = LdaModel(corpus, k, **kwargs)
            ids = save_and_inference(model, corpus, k, kwargs['chunksize'])
            # result_dict[k] = (model, ids)  # not enough memory to keep them all with 4M sentences
            eval_and_write(None, k, documents, dictionary, corpus, model, ids)
            del model, ids
            gc.collect()
    output(f"===> {_ARGS.name} complete. \n")


def main():
    if os.path.isfile(cache_path(_ARGS.name)):
        if _ARGS.name == 'clean':
            # c, d = load_cache('run/' + _ARGS.name)
            # documents = df['tokens'].to_list()
            # dump_cache((c, d, documents), 'run/' + _ARGS.name)
            pipeline(None)
            return
        else:
            df = load_cache(_ARGS.name)
    else:
        if _ARGS.name == 'clean':
            dfs = list()
            for i in range(6):
                path = f"./dev/data/clean{i}_covid19.xlsx"
                if os.path.isfile(cache_path(f'clean{i}')):
                    part = load_cache(f'clean{i}')
                else:
                    part = read(path)
                    dump_cache(part, f'clean{i}')
                dfs.append(part)
            df = concat(dfs, ignore_index=True)
        else:
            path = f"./dev/data/{_ARGS.name}_covid19.xlsx"
            df = read(path)
            dump_cache(df, _ARGS.name)
    # logging.disable(level=logging.INFO)
    pipeline(df)
    return


if __name__ == "__main__":
    main()

"""
Count the number of posts per topic per day, and the daily (per-topic) sentiment change.
Processing 331k posts used to take 7 hours on 2 GPUs; it now takes 20 minutes on 7 GPUs.
"""