def main(_):
  with open("./data/normal_hyperparams.pkl", 'rb') as f:
    config = pickle.load(f)

  with open("./data/normal_vocab.pkl", 'rb') as f:
    vocab_i2c = pickle.load(f)

  vocab_size = len(vocab_i2c)
  vocab_c2i = dict(zip(vocab_i2c, range(vocab_size)))

  with tf.variable_scope('normal'):
    model = CharRNN(vocab_size=vocab_size,
                    batch_size=1,
                    rnn_size=config['rnn_size'],
                    layer_depth=config['layer_depth'],
                    num_units=config['num_units'],
                    seq_length=1,
                    keep_prob=config['keep_prob'],
                    grad_clip=config['grad_clip'],
                    rnn_type=config['rnn_type'])

  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state("checkpoint/normal")
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
    print("Done!")

    while True:
      sentence = input()
      chars = [GO] + list(sentence) + [EOS]
      fw_ints = [vocab_c2i.get(c, UNK_ID) for c in chars]
      print(fw_ints)
      loss, _ = model.get_loss(sess, fw_ints)
      print("ppl", np.exp(-loss))
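
# --- Illustrative sketch (not from the original script) ---------------------
# The REPL above prints ppl = exp(-loss), which assumes model.get_loss returns
# the mean log-probability of the observed characters. The tiny helper below
# shows the same relation on made-up probability values.
def _perplexity_demo():
  import numpy as np
  char_probs = [0.20, 0.05, 0.50]        # made-up per-character probabilities
  mean_log_prob = np.mean(np.log(char_probs))
  return float(np.exp(-mean_log_prob))   # higher probabilities -> lower perplexity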
def main(_):
  if len(sys.argv) < 2:
    print("Please enter a prime")
    sys.exit()

  prime = sys.argv[1]
  prime = prime.decode('utf-8')

  with open("./log/hyperparams.pkl", 'rb') as f:
    config = cPickle.load(f)

  if not os.path.exists(config['checkpoint_dir']):
    print(" [*] Creating checkpoint directory...")
    os.makedirs(config['checkpoint_dir'])

  data_loader = TextLoader(
      os.path.join(config['data_dir'], config['dataset_name']),
      config['batch_size'], config['seq_length'])
  vocab_size = data_loader.vocab_size

  with tf.variable_scope('model'):
    model = CharRNN(vocab_size, 1, config['rnn_size'],
                    config['layer_depth'], config['num_units'],
                    1, config['keep_prob'], config['grad_clip'],
                    is_training=False)

  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(config['checkpoint_dir'] + '/' + config['dataset_name'])
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)

    res = model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 100, prime)
    print(res)
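
# --- Usage sketch (assumption) -----------------------------------------------
# Assuming this entry point lives in a script such as sample.py (hypothetical
# filename) and the checkpoint plus ./log/hyperparams.pkl exist, it is invoked
# with the priming text as the first command-line argument, e.g.:
#
#   python sample.py 今天天氣
#
# sys.argv[1].decode('utf-8') and cPickle imply a Python 2 interpreter.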
def main(_):
  pp.pprint(FLAGS.__flags)

  if not os.path.exists(FLAGS.checkpoint_dir):
    print(" [*] Creating checkpoint directory...")
    os.makedirs(FLAGS.checkpoint_dir)

  data_loader = TextLoader(os.path.join(FLAGS.data_dir, FLAGS.dataset_name),
                           FLAGS.batch_size, FLAGS.seq_length)
  vocab_size = data_loader.vocab_size

  with tf.variable_scope(FLAGS.dataset_name):
    train_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob, FLAGS.grad_clip)

  with tf.variable_scope(FLAGS.dataset_name, reuse=True):
    valid_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob, FLAGS.grad_clip)

  with tf.Session() as sess:
    tf.global_variables_initializer().run()
    train_model.load(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)

    best_val_pp = float('inf')
    best_val_epoch = 0
    valid_loss = 0
    valid_perplexity = 0
    start = time.time()

    if FLAGS.export:
      print("Eval...")
      final_embeddings = train_model.embedding.eval(sess)
      emb_file = os.path.join(FLAGS.data_dir, FLAGS.dataset_name, 'emb.npy')
      print("Embedding shape: {}".format(final_embeddings.shape))
      np.save(emb_file, final_embeddings)

    else:
      if not os.path.exists(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)

      # save hyper-parameters
      with open(FLAGS.log_dir + "/" + FLAGS.dataset_name + "_hyperparams.pkl", 'wb') as f:
        cPickle.dump(FLAGS.__flags, f)

      for e in range(FLAGS.num_epochs):
        data_loader.reset_batch_pointer()

        # assign the current learning rate, then halve it for the next epoch
        sess.run(tf.assign(train_model.lr, FLAGS.learning_rate))
        FLAGS.learning_rate /= 2

        for b in range(data_loader.num_batches):
          x, y = data_loader.next_batch()
          res, time_batch = run_minibatches(sess, x, y, train_model)
          train_loss = res["loss"]
          train_perplexity = np.exp(train_loss)

          print("{}/{} (epoch {}) loss = {:.2f}({:.2f}) perplexity(train/valid) = {:.2f}({:.2f}) time/batch = {:.2f} chars/sec = {:.2f}k"
                .format(data_loader.pointer, data_loader.num_batches, e,
                        train_loss, valid_loss, train_perplexity, valid_perplexity,
                        time_batch, (FLAGS.batch_size * FLAGS.seq_length) / time_batch / 1000))

        # validate after each epoch
        valid_loss = 0
        for vb in range(data_loader.num_valid_batches):
          res, valid_time_batch = run_minibatches(sess, data_loader.x_valid[vb],
                                                  data_loader.y_valid[vb],
                                                  valid_model, False)
          valid_loss += res["loss"]

        valid_loss = valid_loss / data_loader.num_valid_batches
        valid_perplexity = np.exp(valid_loss)
        print("### valid_perplexity = {:.2f}, time/batch = {:.2f}"
              .format(valid_perplexity, valid_time_batch))

        # keep the best model on validation perplexity
        if valid_perplexity < best_val_pp:
          best_val_pp = valid_perplexity
          best_val_epoch = e
          train_model.save(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)
          print("model saved to {}".format(FLAGS.checkpoint_dir))

        # early stopping
        if e - best_val_epoch > FLAGS.early_stopping:
          print('Total time: {}'.format(time.time() - start))
          break
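
# --- Illustrative sketch (not from the original script) ---------------------
# The trainer above assigns FLAGS.learning_rate to the model and then halves
# the flag once per epoch, so the effective schedule is lr_e = lr_0 / 2**e.
# Defaults below are made-up illustration values.
def _print_halving_schedule(initial_lr=0.001, num_epochs=5):
  lr = initial_lr
  for e in range(num_epochs):
    print("epoch %d: lr = %g" % (e, lr))   # equals initial_lr / 2**e
    lr /= 2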
def main(_):
  pp.pprint(FLAGS.__flags)

  if not os.path.exists(FLAGS.checkpoint_dir):
    print(" [*] Creating checkpoint directory...")
    os.makedirs(FLAGS.checkpoint_dir)

  data_loader = TextLoader(os.path.join(FLAGS.data_dir, FLAGS.dataset_name),
                           FLAGS.batch_size, FLAGS.seq_length)
  vocab_size = data_loader.vocab_size
  valid_size = 50
  valid_window = 100

  with tf.variable_scope('model'):
    train_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob, FLAGS.grad_clip,
                          FLAGS.nce_samples)

  learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                             train_model.global_step,
                                             data_loader.num_batches,
                                             FLAGS.grad_clip,
                                             staircase=True)

  with tf.variable_scope('model', reuse=True):
    simple_model = CharRNN(vocab_size, 1, FLAGS.rnn_size,
                           FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                           1, FLAGS.keep_prob, FLAGS.grad_clip)

  with tf.variable_scope('model', reuse=True):
    valid_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob, FLAGS.grad_clip)

  with tf.Session() as sess:
    tf.global_variables_initializer().run()

    best_val_pp = float('inf')
    best_val_epoch = 0
    valid_loss = 0
    valid_perplexity = 0
    start = time.time()

    if FLAGS.export:
      print("Eval...")
      final_embeddings = train_model.embedding.eval(sess)
      emb_file = os.path.join(FLAGS.data_dir, FLAGS.dataset_name, 'emb.npy')
      print("Embedding shape: {}".format(final_embeddings.shape))
      np.save(emb_file, final_embeddings)

    else:
      # Train
      current_step = 0
      similarity, valid_examples, _ = compute_similarity(train_model, valid_size, valid_window, 6)

      # save hyper-parameters
      cPickle.dump(FLAGS.__flags, open(FLAGS.log_dir + "/hyperparams.pkl", 'wb'))

      # run it!
      for e in range(FLAGS.num_epochs):
        data_loader.reset_batch_pointer()

        # decay learning rate
        sess.run(tf.assign(train_model.lr, learning_rate))

        # iterate by batch
        for b in range(data_loader.num_batches):
          x, y = data_loader.next_batch()
          res, time_batch = run_epochs(sess, x, y, train_model)
          train_loss = res["loss"][0]
          train_perplexity = np.exp(train_loss)
          iterate = e * data_loader.num_batches + b

          if current_step != 0 and current_step % FLAGS.valid_every == 0:
            # validate
            valid_loss = 0
            for vb in range(data_loader.num_valid_batches):
              res, valid_time_batch = run_epochs(sess, data_loader.x_valid[vb],
                                                 data_loader.y_valid[vb],
                                                 valid_model, False)
              valid_loss += res["loss"][0]

            valid_loss = valid_loss / data_loader.num_valid_batches
            valid_perplexity = np.exp(valid_loss)
            print("### valid_perplexity = {:.2f}, time/batch = {:.2f}".format(valid_perplexity, valid_time_batch))

            log_str = ""

            # Generate sample
            smp1 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"我喜歡做")
            smp2 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"他吃飯時會用")
            smp3 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"人類總要重複同樣的")
            smp4 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"天色暗了,好像快要")

            log_str = log_str + smp1 + "\n"
            log_str = log_str + smp2 + "\n"
            log_str = log_str + smp3 + "\n"
            log_str = log_str + smp4 + "\n"

            # Write a similarity log
            # Note that this is expensive (~20% slowdown if computed every 500 steps)
            sim = similarity.eval()
            for i in range(valid_size):
              valid_word = data_loader.chars[valid_examples[i]]
              top_k = 8  # number of nearest neighbors
              nearest = (-sim[i, :]).argsort()[1:top_k + 1]
              log_str = log_str + "Nearest to %s:" % valid_word

              for k in range(top_k):
                close_word = data_loader.chars[nearest[k]]
                log_str = "%s %s," % (log_str, close_word)

              log_str = log_str + "\n"

            print(log_str)

            # Write to log
            text_file = codecs.open(FLAGS.log_dir + "/similarity.txt", "w", "utf-8")
            text_file.write(log_str)
            text_file.close()

          # print log
          print("{}/{} (epoch {}) loss = {:.2f}({:.2f}) perplexity(train/valid) = {:.2f}({:.2f}) time/batch = {:.2f} chars/sec = {:.2f}k"
                .format(e * data_loader.num_batches + b,
                        FLAGS.num_epochs * data_loader.num_batches,
                        e, train_loss, valid_loss, train_perplexity, valid_perplexity,
                        time_batch, (FLAGS.batch_size * FLAGS.seq_length) / time_batch / 1000))

          current_step = tf.train.global_step(sess, train_model.global_step)

          if valid_perplexity < best_val_pp:
            best_val_pp = valid_perplexity
            best_val_epoch = iterate

            # save best model
            train_model.save(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)
            print("model saved to {}".format(FLAGS.checkpoint_dir))

          # early_stopping
          if iterate - best_val_epoch > FLAGS.early_stopping:
            print('Total time: {}'.format(time.time() - start))
            break
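
# --- Sketch of compute_similarity (assumption, not the original helper) -----
# The trainer above calls compute_similarity(train_model, valid_size,
# valid_window, 6) and later evaluates the returned `similarity` tensor to list
# nearest-neighbour characters. One plausible TF1-style implementation, in the
# spirit of the classic word2vec evaluation, samples `valid_size` ids from the
# first `valid_window` characters (skipping the first `offset` special symbols,
# an assumed meaning for the fourth argument) and builds a cosine-similarity op
# against the model embedding:
def compute_similarity_sketch(model, valid_size, valid_window, offset):
  import numpy as np
  import tensorflow as tf

  valid_examples = np.random.choice(np.arange(offset, valid_window),
                                    valid_size, replace=False)
  valid_dataset = tf.constant(valid_examples, dtype=tf.int32)

  # cosine similarity = dot product of L2-normalized embedding rows
  norm = tf.sqrt(tf.reduce_sum(tf.square(model.embedding), 1, keep_dims=True))
  normalized_embeddings = model.embedding / norm
  valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset)
  similarity = tf.matmul(valid_embeddings, normalized_embeddings, transpose_b=True)
  return similarity, valid_examples, valid_dataset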
def main(_):
  pp.pprint(FLAGS.__flags)

  if not os.path.exists(FLAGS.checkpoint_dir):
    print(" [*] Creating checkpoint directory...")
    os.makedirs(FLAGS.checkpoint_dir)

  data_loader = TextLoader(os.path.join(FLAGS.data_dir, FLAGS.dataset_name),
                           FLAGS.batch_size, FLAGS.seq_length)
  vocab_size = data_loader.vocab_size
  valid_size = 50
  valid_window = 100

  with tf.variable_scope('model'):
    train_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob, FLAGS.grad_clip)

  with tf.variable_scope('model', reuse=True):
    simple_model = CharRNN(vocab_size, 1, FLAGS.rnn_size,
                           FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                           1, FLAGS.keep_prob, FLAGS.grad_clip)

  with tf.variable_scope('model', reuse=True):
    valid_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob, FLAGS.grad_clip)

  with tf.Session() as sess:
    tf.global_variables_initializer().run()
    train_model.load(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)

    best_val_pp = float('inf')
    best_val_epoch = 0
    valid_loss = 0
    valid_perplexity = 0
    start = time.time()

    if FLAGS.export:
      print("Eval...")
      final_embeddings = train_model.embedding.eval(sess)
      emb_file = os.path.join(FLAGS.data_dir, FLAGS.dataset_name, 'emb.npy')
      print("Embedding shape: {}".format(final_embeddings.shape))
      np.save(emb_file, final_embeddings)

    else:
      # Train
      current_step = 0
      similarity, valid_examples, _ = compute_similarity(train_model, valid_size, valid_window, 6)

      # save hyper-parameters
      cPickle.dump(FLAGS.__flags, open(FLAGS.log_dir + "/hyperparams.pkl", 'wb'))

      # run it!
      for e in range(FLAGS.num_epochs):
        data_loader.reset_batch_pointer()

        # decay learning rate
        sess.run(tf.assign(train_model.lr, FLAGS.learning_rate))

        # iterate by batch
        for b in range(data_loader.num_batches):
          x, y = data_loader.next_batch()
          res, time_batch = run_epochs(sess, x, y, train_model)
          train_loss = res["loss"]
          train_perplexity = np.exp(train_loss)
          iterate = e * data_loader.num_batches + b

          # print log
          print("{}/{} (epoch {}) loss = {:.2f}({:.2f}) perplexity(train/valid) = {:.2f}({:.2f}) time/batch = {:.2f} chars/sec = {:.2f}k"
                .format(e * data_loader.num_batches + b,
                        FLAGS.num_epochs * data_loader.num_batches,
                        e, train_loss, valid_loss, train_perplexity, valid_perplexity,
                        time_batch, (FLAGS.batch_size * FLAGS.seq_length) / time_batch / 1000))

          current_step = tf.train.global_step(sess, train_model.global_step)

          # validate
          valid_loss = 0
          for vb in range(data_loader.num_valid_batches):
            res, valid_time_batch = run_epochs(sess, data_loader.x_valid[vb],
                                               data_loader.y_valid[vb],
                                               valid_model, False)
            valid_loss += res["loss"]

          valid_loss = valid_loss / data_loader.num_valid_batches
          valid_perplexity = np.exp(valid_loss)
          print("### valid_perplexity = {:.2f}, time/batch = {:.2f}".format(valid_perplexity, valid_time_batch))

          log_str = ""

          # Generate sample
          smp1 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"我喜歡做")
          smp2 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"他吃飯時會用")
          smp3 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"人類總要重複同樣的")
          smp4 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"天色暗了,好像快要")

          log_str = log_str + smp1 + "\n"
          log_str = log_str + smp2 + "\n"
          log_str = log_str + smp3 + "\n"
          log_str = log_str + smp4 + "\n"

          # Write a similarity log
          # Note that this is expensive (~20% slowdown if computed every 500 steps)
          sim = similarity.eval()
          for i in range(valid_size):
            valid_word = data_loader.chars[valid_examples[i]]
            top_k = 8  # number of nearest neighbors
            nearest = (-sim[i, :]).argsort()[1:top_k + 1]
            log_str = log_str + "Nearest to %s:" % valid_word

            for k in range(top_k):
              close_word = data_loader.chars[nearest[k]]
              log_str = "%s %s," % (log_str, close_word)

            log_str = log_str + "\n"

          print(log_str)

          # Write to log
          text_file = codecs.open(FLAGS.log_dir + "/similarity.txt", "w", "utf-8")
          text_file.write(log_str)
          text_file.close()

          if valid_perplexity < best_val_pp:
            best_val_pp = valid_perplexity
            best_val_epoch = iterate

            # save best model
            train_model.save(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)
            print("model saved to {}".format(FLAGS.checkpoint_dir))

          # early_stopping
          if iterate - best_val_epoch > FLAGS.early_stopping:
            print('Total time: {}'.format(time.time() - start))
            break
class LanguageCorrector():
  '''
  Natural Language Correction Model
  '''

  def __init__(self, fw_hyp_path, bw_hyp_path,
               fw_vocab_path, bw_vocab_path,
               fw_model_path, bw_model_path,
               dictionary_path, threshold=math.exp(-6.8)):
    '''
    Load solver
    :param fw_hyp_path: forward model hyperparam path
    :param bw_hyp_path: backward model hyperparam path
    :param fw_vocab_path: forward model vocab path
    :param bw_vocab_path: backward model vocab path
    :param dictionary_path: dictionary path
    :param threshold: threshold for model
    '''
    jieba.load_userdict(dictionary_path)
    self.threshold = np.log(threshold)

    # load configs
    with open(fw_hyp_path, 'rb') as f:
      fw_hyp_config = pickle.load(f)
    with open(bw_hyp_path, 'rb') as f:
      bw_hyp_config = pickle.load(f)

    # load vocabularies
    with open(fw_vocab_path, 'rb') as f:
      self.fw_vocab_i2c = pickle.load(f)
      self.fw_vocab_size = len(self.fw_vocab_i2c)
      self.fw_vocab_c2i = dict(zip(self.fw_vocab_i2c, range(self.fw_vocab_size)))
    with open(bw_vocab_path, 'rb') as f:
      self.bw_vocab_i2c = pickle.load(f)
      self.bw_vocab_size = len(self.bw_vocab_i2c)
      self.bw_vocab_c2i = dict(zip(self.bw_vocab_i2c, range(self.bw_vocab_size)))

    # load forward model
    g1 = tf.Graph()
    self.fw_sess = tf.Session(graph=g1)
    with self.fw_sess.as_default():
      with g1.as_default():
        with tf.variable_scope(fw_hyp_config['dataset_name']):
          self.fw_model = CharRNN(vocab_size=self.fw_vocab_size,
                                  batch_size=1,
                                  rnn_size=fw_hyp_config['rnn_size'],
                                  layer_depth=fw_hyp_config['layer_depth'],
                                  num_units=fw_hyp_config['num_units'],
                                  seq_length=1,
                                  keep_prob=fw_hyp_config['keep_prob'],
                                  grad_clip=fw_hyp_config['grad_clip'],
                                  rnn_type=fw_hyp_config['rnn_type'])
          ckpt = tf.train.get_checkpoint_state(fw_model_path + '/' + fw_hyp_config['dataset_name'])
          tf.train.Saver().restore(self.fw_sess, ckpt.model_checkpoint_path)
          # print("fwmodel done!")

    # load backward model
    g2 = tf.Graph()
    self.bw_sess = tf.Session(graph=g2)
    with self.bw_sess.as_default():
      with g2.as_default():
        with tf.variable_scope(bw_hyp_config['dataset_name']):
          self.bw_model = CharRNN(vocab_size=self.bw_vocab_size,
                                  batch_size=1,
                                  rnn_size=bw_hyp_config['rnn_size'],
                                  layer_depth=bw_hyp_config['layer_depth'],
                                  num_units=bw_hyp_config['num_units'],
                                  seq_length=1,
                                  keep_prob=bw_hyp_config['keep_prob'],
                                  grad_clip=bw_hyp_config['grad_clip'],
                                  rnn_type=bw_hyp_config['rnn_type'])
          ckpt = tf.train.get_checkpoint_state(bw_model_path + '/' + bw_hyp_config['dataset_name'])
          tf.train.Saver().restore(self.bw_sess, ckpt.model_checkpoint_path)
          # print("bwmodel done!")

    # load dictionary
    with open(dictionary_path, "r", encoding="utf-8") as f:
      self.dictionary = set()
      self.word_max_length = 0
      for line in f:
        line = line.strip()
        if len(line) == 0:
          continue
        self.dictionary.add(line)
        if self.word_max_length < len(line):
          self.word_max_length = len(line)

  def correctify(self, sentence):
    '''
    Requested Method
    :param sentence: input sentence
    :return: corrections
    '''
    chars = [GO] + list(sentence) + [EOS]
    bw_chars = chars[::-1]
    sz = len(chars)
    fw_ints = [self.fw_vocab_c2i.get(c, UNK_ID) for c in chars]
    bw_ints = [self.bw_vocab_c2i.get(c, UNK_ID) for c in bw_chars]
    fw_losses = {}  # loss after consuming each character
    bw_losses = {}
    fw_probs = {}
    bw_probs = {}

    # find bad guys
    bads_or_not = []
    bad_pos = set()
    for i in range(sz):
      bads_or_not.append(False)

    # fw side
    with self.fw_sess.as_default():
      with self.fw_sess.graph.as_default():
        for i in range(1, sz - 1):
          fw_substr_ints = fw_ints[:i + 1]
          fw_losses[i], fw_probs[i] = self.fw_model.get_loss(self.fw_sess, fw_substr_ints)

    # bw side
    with self.bw_sess.as_default():
      with self.bw_sess.graph.as_default():
        for i in range(1, sz - 1):
          bw_substr_ints = bw_ints[:i + 1]
          bw_losses[i], bw_probs[i] = self.bw_model.get_loss(self.bw_sess, bw_substr_ints)

    # first view
    results = []
    for i in range(1, sz - 1):
      # print(fw_losses[i], bw_losses[sz - 1 - i])
      t_loss = fw_losses[i] + bw_losses[sz - 1 - i]
      # print(t_loss)
      # print(chars[:i + 1])
      results.append([i, t_loss])

    results = list(sorted(results, key=lambda x: x[1]))
    for i in range((len(results) + 1) // 2):
      score = results[i][1]
      if score < self.threshold:
        pos = results[i][0]
        bads_or_not[pos] = True
        bad_pos.add(pos)

    # second view
    for p in range(sz):
      if not bads_or_not[p]:
        continue
      for word_len in range(2, self.word_max_length):
        left_p = max(1, p - word_len + 1)
        right_p = min(sz - 2 - word_len + 1, p)
        for left in range(left_p, right_p + 1):
          subword = sentence[left - 1:left - 1 + word_len]
          if subword in self.dictionary:
            # print(subword)
            bads_or_not[p] = False
            bad_pos.remove(p)
          if not bads_or_not[p]:
            break
        if not bads_or_not[p]:
          break

    # print(bad_pos)

    # find candidates
    for p in bad_pos:
      if not sz - 2 >= p >= 1:
        continue

      best_ch = ""
      best_score = np.NINF

      left_ints = fw_ints[:p]
      left_sz = len(left_ints)
      with self.fw_sess.as_default():
        with self.fw_sess.graph.as_default():
          left_loss, left_probs = self.fw_model.get_loss(self.fw_sess, left_ints)

      right_ints = bw_ints[:sz - 1 - p]
      right_sz = len(right_ints)
      with self.bw_sess.as_default():
        with self.bw_sess.graph.as_default():
          right_loss, right_probs = self.bw_model.get_loss(self.bw_sess, right_ints)

      for ic, ch in enumerate(self.fw_vocab_i2c):
        if ch in START_VOCAB or ch in punctuation or ch in ALL_PUNCTUATION:
          continue
        loss = (left_loss * left_sz + math.log(left_probs[ic])) / (left_sz + 1) + \
               (right_loss * right_sz + math.log(right_probs[self.bw_vocab_c2i[ch]])) / (right_sz + 1)
        if loss > best_score + 1e-6:
          best_score = loss
          best_ch = ch

      chars[p] = best_ch

    chars = "".join(chars[1:-1])
    assert len(chars) == len(sentence)
    segs = jieba.cut(chars)

    # get requested format
    ans_array = []
    a_p = 0
    for seg in segs:
      s_s = a_p
      t_t = s_s + len(seg)
      if chars[s_s:t_t] != sentence[s_s:t_t]:
        ans_array.append({
            "sourceValue": sentence[s_s:t_t],
            "correctValue": chars[s_s:t_t],
            "startOffset": s_s,
            "endOffset": t_t
        })
      a_p = t_t

    return ans_array
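
# --- Usage sketch (assumption; all paths below are hypothetical) ------------
# The class above is built from forward/backward hyperparameter pickles, vocab
# pickles, checkpoint directories and a user dictionary; correctify() returns a
# list of {sourceValue, correctValue, startOffset, endOffset} dicts.
def _corrector_demo():
  corrector = LanguageCorrector(fw_hyp_path="log/fw_hyperparams.pkl",
                                bw_hyp_path="log/bw_hyperparams.pkl",
                                fw_vocab_path="data/fw_vocab.pkl",
                                bw_vocab_path="data/bw_vocab.pkl",
                                fw_model_path="checkpoint/fw",
                                bw_model_path="checkpoint/bw",
                                dictionary_path="data/dictionary.txt")
  for fix in corrector.correctify(u"今天天氣真好"):
    print(fix["sourceValue"], "->", fix["correctValue"],
          fix["startOffset"], fix["endOffset"])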