# Mixed character-word embedding
# 93.5%+

import glob
import os
import random

from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing import sequence

# Tokenizer, load_THUCNews_title_label, find_best_maxlen and preprocess are
# the project's own helpers; their imports are omitted here.

# Prepare the data
X, y, classes = load_THUCNews_title_label()
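# X: raw title strings, y: integer label ids, classes: label names
# (assumed from how these values are used below)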
X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    train_size=0.8,
                                                    random_state=7322)

num_classes = len(classes)
# Convert the texts to token ids
print("tokenize...")
tokenizer = Tokenizer(mintf=32, cutword=True)
tokenizer.fit_in_parallel(X_train)

# maxlen = find_best_maxlen(X_train, mode="max")
maxlen = 48
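
# Rough sanity check (illustrative only) that the fixed maxlen covers most
# titles: look at a high percentile of the character lengths in X_train.
# find_best_maxlen above is the project's own helper for choosing this value.
_lengths = sorted(len(x) for x in X_train)
print("~98th percentile title length:", _lengths[int(0.98 * len(_lengths))])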


def pad(X, maxlen):
    return sequence.pad_sequences(X,
                                  maxlen=maxlen,
                                  dtype="int32",
                                  padding="post",
                                  truncating="post",
                                  value=0)
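
# Minimal check of pad()'s behaviour (toy ids, not tokenizer output): short
# sequences are zero-padded on the right, long ones are truncated on the right.
_demo = pad([[5, 2, 9], list(range(1, 60))], maxlen=8)
assert _demo.shape == (2, 8)
assert _demo[0].tolist() == [5, 2, 9, 0, 0, 0, 0, 0]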


def load_sentences(limit=None, shuffle=True):
    # Yield preprocessed text from the raw corpus files,
    # optionally shuffled and truncated to the first `limit` files.
    files = glob.glob("corpus/*.txt")  # corpus location is an assumption
    if shuffle:
        random.shuffle(files)
    for file in files[:limit]:
        with open(file, encoding="utf-8") as fd:
            content = fd.read()
        yield preprocess(content)


file = "word_meta.json"
tokenizer = Tokenizer(mintf, processes)
if os.path.exists(file):
    tokenizer.load(file)
else:
    X = load_sentences(limit=None)
    print("tokenize...")
    tokenizer.fit_in_parallel(X)
    tokenizer.save(file)

words = tokenizer.words
word2id = tokenizer.word2id
id2word = {j: i for i, j in word2id.items()}
vocab_size = len(tokenizer)
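
# create_subsamples below builds a word2vec-style subsampling table: a token
# with relative frequency f is kept with probability
#     p = (sqrt(f / t) + 1) * t / f,  where t = subsample_eps,
# so very frequent tokens are heavily downsampled while rare ones are kept
# (cf. tf.keras.preprocessing.sequence.make_sampling_table). Illustrative
# values of the formula, independent of the tokenizer above:
for _f in (5e-2, 1e-4, 1e-6):
    _keep = min(1.0, ((_f / 1e-5) ** 0.5 + 1) * 1e-5 / _f)
    print(f"f={_f:g} -> keep probability {_keep:.3f}")
    # f=5e-2 -> ~0.014, f=1e-4 -> ~0.42, f=1e-6 -> 1.0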


def create_subsamples(words, subsample_eps=1e-5):
    # Build the subsampling table (applied to context words),
    # cf. tf.keras.preprocessing.sequence.make_sampling_table
    total = sum(words.values())  # total token count, so j / total is a relative frequency
    subsamples = {}
    for i, j in words.items():
        j = j / total