from model import textCNN_train_test, MAX_SEQ_LEN, in_path
from time import time

# IMDB sentiment pipeline: load the IMDB reviews, decode them back to word
# tokens, train word vectors, convert the tokens to index sequences and split
# them into stratified training/validation sets with one-hot labels.
#
# NOTE(review): imdb, np, train_test_split, train_W2V,
# build_word2idx_embedMatrix, make_X_train_idx and make_y_train_oneHot are not
# imported in the visible lines; presumably they are imported elsewhere
# (imdb looks like keras.datasets.imdb, np like numpy, train_test_split like
# scikit-learn's) -- confirm.
if __name__ == '__main__':
    print('——————————————load data——————————————')
    (X_train, y_train), (X_test, y_test) = imdb.load_data()
    # Pool the original train and test splits into one corpus; it is re-split
    # below. The [0:] slice is a full copy of an already-new list and has no
    # effect here (perhaps kept as a knob for sub-sampling while debugging).
    X_all = (list(X_train) + list(X_test))[0:]
    y_all = (list(y_train) + list(y_test))[0:]
    print(len(X_all), len(y_all))

    # Build the reverse (index -> word) vocabulary from the word -> index map.
    imdb_word2idx = imdb.get_word_index()
    imdb_idx2word = dict((idx, word) for (word, idx) in imdb_word2idx.items())
    # Decode each review from integer ids back to word tokens.
    # Assumes imdb.load_data() was called with its Keras defaults
    # (start_char=1, oov_char=2, index_from=3): subtracting 3 recovers the
    # get_word_index() id, and reserved ids have no entry so they become '?'.
    # [1:] drops the first token of every review (presumably the start marker).
    X_all = [[imdb_idx2word.get(idx - 3, '?') for idx in sen][1:] for sen in X_all]

    # Presumably trains word vectors on the whole decoded corpus;
    # in_path + 'w2vModel' is passed as the model path (save/load behaviour
    # lives in train_W2V -- confirm).
    w2vModel = train_W2V(X_all, in_path + 'w2vModel')
    # Build word2idx and embedMatrix from the word-vector model.
    word2idx, embedMatrix = build_word2idx_embedMatrix(w2vModel)
    # Map tokens to index sequences; presumably padded/truncated to MAX_SEQ_LEN.
    X_all_idx = make_X_train_idx(X_all, word2idx, MAX_SEQ_LEN)
    # Important: X_all and y_all must be np.array(), otherwise an error is
    # raised (presumably by train_test_split / the model -- confirm).
    y_all_idx = np.array(y_all)
    # 80/20 split with a fixed seed; stratify keeps the label proportions the
    # same in both parts.
    X_tra_idx, X_val_idx, y_tra_idx, y_val_idx = train_test_split(
        X_all_idx, y_all_idx, test_size=0.2, random_state=0, stratify=y_all_idx)
    # One-hot encode the labels of both splits.
    y_tra_oneHot = make_y_train_oneHot(y_tra_idx)
    y_val_oneHot = make_y_train_oneHot(y_val_idx)
    # NOTE(review): the header below announces "model training and prediction",
    # but no training call follows it in the visible code, and
    # textCNN_train_test / time are imported but never used here. The
    # training/evaluation step appears to be missing or to live past this
    # chunk -- confirm.
    print('——————————————模型的训练和预测——————————————')
# NOTE(review): in this chunk the only preceding `if` is
# `if __name__ == '__main__':`, which would make this branch run on import.
# The surrounding text looks like two scripts joined together; confirm which
# `if` this `else:` is meant to pair with. The branch also appears to continue
# past this chunk (yelp_2014_data is created empty at the end).
else:
    # Yelp 2014 preprocessing: read the train/test files, split each document
    # into sentences and each sentence into words, train word vectors, and
    # convert documents and labels to index / one-hot form.
    #
    # NOTE(review): load_data, split_words, W2V_corpus_iter, train_W2V,
    # build_word2idx_embedMatrix, make_X_train_idx, make_y_train_oneHot,
    # MAX_SENT_NUM and MAX_SENT_LEN are not defined or imported in the visible
    # lines; presumably they are provided elsewhere -- confirm.
    print('preprocess_data')
    X_train, y_train = load_data(in_path + 'yelp-2014-seg-20-20.train.ss')
    X_test, y_test = load_data(in_path + 'yelp-2014-seg-20-20.test.ss')
    # Split each document into sentences on the literal ' <sssss> ' separator.
    X_train = [paragraph.split(' <sssss> ') for paragraph in X_train]
    X_test = [paragraph.split(' <sssss> ') for paragraph in X_test]
    # Tokenize every sentence with split_words (presumably into a list of words).
    X_train = [[split_words(sent) for sent in paragraph] for paragraph in X_train]
    X_test = [[split_words(sent) for sent in paragraph] for paragraph in X_test]
    # The word-vector corpus is built from the training documents only
    # (X_test is not passed), presumably to keep test text out of the embeddings.
    W2V_corpus = W2V_corpus_iter(X_train)
    w2vModel = train_W2V(W2V_corpus, in_path + 'w2vModel')
    # Build word2idx and embedMatrix from the word-vector model.
    word2idx, embedMatrix = build_word2idx_embedMatrix(w2vModel)
    # Convert documents to index form; presumably padded/truncated to
    # MAX_SENT_NUM sentences of MAX_SENT_LEN words each -- confirm.
    X_train_idx = make_X_train_idx(X_train, word2idx, MAX_SENT_NUM, MAX_SENT_LEN)
    X_test_idx = make_X_train_idx(X_test, word2idx, MAX_SENT_NUM, MAX_SENT_LEN)
    # One-hot encode the labels; what is_cate_dict=True selects is defined by
    # make_y_train_oneHot, which is not visible here.
    y_train_oneHot = make_y_train_oneHot(y_train, is_cate_dict=True)
    y_test_oneHot = make_y_train_oneHot(y_test, is_cate_dict=True)
    print(len(X_train_idx), len(X_test_idx), len(y_train_oneHot), len(y_test_oneHot))
    # Created empty here; presumably filled with the arrays above past this chunk.
    yelp_2014_data = {}