Example #1
def main(_):
    with open("./data/normal_hyperparams.pkl", 'rb') as f:
        config = pickle.load(f)

    with open("./data/normal_vocab.pkl", 'rb') as f:
        vocab_i2c = pickle.load(f)
        vocab_size = len(vocab_i2c)
        vocab_c2i = dict(zip(vocab_i2c, range(vocab_size)))

    with tf.variable_scope('normal'):
        model = CharRNN(vocab_size=vocab_size,
                        batch_size=1,
                        rnn_size=config['rnn_size'],
                        layer_depth=config['layer_depth'],
                        num_units=config['num_units'],
                        seq_length=1,
                        keep_prob=config['keep_prob'],
                        grad_clip=config['grad_clip'],
                        rnn_type=config['rnn_type'])

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state("checkpoint/normal")
        tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
        print("Done!")

        while True:
            sentence = input()
            chars = [GO] + list(sentence) + [EOS]
            fw_ints = [vocab_c2i.get(c, UNK_ID) for c in chars]
            print(fw_ints)

            loss, _ = model.get_loss(sess, fw_ints)
            print("ppl", np.exp(-loss))
Example #2
def main(_):
    if len(sys.argv) < 2:
        print("Please enter a prime")
        sys.exit()

    prime = sys.argv[1]
    prime = prime.decode('utf-8')

    with open("./log/hyperparams.pkl", 'rb') as f:
        config = cPickle.load(f)

    if not os.path.exists(config['checkpoint_dir']):
        print(" [*] Creating checkpoint directory...")
        os.makedirs(config['checkpoint_dir'])

    data_loader = TextLoader(
        os.path.join(config['data_dir'], config['dataset_name']),
        config['batch_size'], config['seq_length'])
    vocab_size = data_loader.vocab_size

    with tf.variable_scope('model'):
        model = CharRNN(vocab_size,
                        1,
                        config['rnn_size'],
                        config['layer_depth'],
                        config['num_units'],
                        1,
                        config['keep_prob'],
                        config['grad_clip'],
                        is_training=False)

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(config['checkpoint_dir'] + '/' +
                                             config['dataset_name'])
        tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)

        res = model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID,
                           100, prime)

        print(res)
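
CharRNN.sample is not shown in these snippets; as a rough, self-contained illustration of the per-character draw such a method typically performs, here is a NumPy sketch that samples one index from an output distribution (probabilities and temperature are illustrative only):

import numpy as np

def sample_index(probs, temperature=1.0):
    # rescale the distribution by temperature, renormalise, and draw one character index
    logits = np.log(np.asarray(probs, dtype=np.float64) + 1e-12) / temperature
    exp = np.exp(logits - np.max(logits))
    p = exp / exp.sum()
    return int(np.random.choice(len(p), p=p))

print(sample_index([0.1, 0.6, 0.2, 0.1], temperature=0.8))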
Example #3
def main(_):
    pp.pprint(FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        print(" [*] Creating checkpoint directory...")
        os.makedirs(FLAGS.checkpoint_dir)

    data_loader = TextLoader(os.path.join(FLAGS.data_dir, FLAGS.dataset_name),
                             FLAGS.batch_size, FLAGS.seq_length)
    vocab_size = data_loader.vocab_size

    with tf.variable_scope(FLAGS.dataset_name):
        train_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                              FLAGS.layer_depth, FLAGS.num_units,
                              FLAGS.rnn_type, FLAGS.seq_length,
                              FLAGS.keep_prob, FLAGS.grad_clip)

    with tf.variable_scope(FLAGS.dataset_name, reuse=True):
        valid_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                              FLAGS.layer_depth, FLAGS.num_units,
                              FLAGS.rnn_type, FLAGS.seq_length,
                              FLAGS.keep_prob, FLAGS.grad_clip)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        train_model.load(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)

        best_val_pp = float('inf')
        best_val_epoch = 0
        valid_loss = 0
        valid_perplexity = 0
        start = time.time()

        if FLAGS.export:
            print("Eval...")
            final_embeddings = train_model.embedding.eval(sess)
            emb_file = os.path.join(FLAGS.data_dir, FLAGS.dataset_name,
                                    'emb.npy')
            print("Embedding shape: {}".format(final_embeddings.shape))
            np.save(emb_file, final_embeddings)

        else:
            if not os.path.exists(FLAGS.log_dir):
                os.makedirs(FLAGS.log_dir)
            with open(
                    FLAGS.log_dir + "/" + FLAGS.dataset_name +
                    "_hyperparams.pkl", 'wb') as f:
                cPickle.dump(FLAGS.__flags, f)
            for e in range(FLAGS.num_epochs):
                data_loader.reset_batch_pointer()
                sess.run(tf.assign(train_model.lr, FLAGS.learning_rate))
                FLAGS.learning_rate /= 2
                for b in range(data_loader.num_batches):
                    x, y = data_loader.next_batch()
                    res, time_batch = run_minibatches(sess, x, y, train_model)
                    train_loss = res["loss"]
                    train_perplexity = np.exp(train_loss)
                    print(
                        "{}/{} (epoch {}) loss = {:.2f}({:.2f}) perplexity(train/valid) = {:.2f}({:.2f}) time/batch = {:.2f} chars/sec = {:.2f}k" \
                            .format(data_loader.pointer, data_loader.num_batches,
                                    e,
                                    train_loss, valid_loss,
                                    train_perplexity, valid_perplexity,
                                    time_batch, (FLAGS.batch_size * FLAGS.seq_length) / time_batch / 1000))
                valid_loss = 0
                for vb in range(data_loader.num_valid_batches):
                    res, valid_time_batch = run_minibatches(
                        sess, data_loader.x_valid[vb], data_loader.y_valid[vb],
                        valid_model, False)
                    valid_loss += res["loss"]
                valid_loss = valid_loss / data_loader.num_valid_batches
                valid_perplexity = np.exp(valid_loss)
                print("### valid_perplexity = {:.2f}, time/batch = {:.2f}".
                      format(valid_perplexity, valid_time_batch))
                if valid_perplexity < best_val_pp:
                    best_val_pp = valid_perplexity
                    best_val_epoch = e
                    train_model.save(sess, FLAGS.checkpoint_dir,
                                     FLAGS.dataset_name)
                    print("model saved to {}".format(FLAGS.checkpoint_dir))
                if e - best_val_epoch > FLAGS.early_stopping:
                    print('Total time: {}'.format(time.time() - start))
                    break
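
The loop above assigns FLAGS.learning_rate to the model at the start of every epoch and then halves it, so epoch e effectively trains with base_lr / 2**e. A tiny sketch of that schedule with a made-up base rate:

base_lr = 0.002  # hypothetical; the real value comes from FLAGS.learning_rate
lr = base_lr
for epoch in range(5):
    print("epoch {}: lr = {:.6f}".format(epoch, lr))
    lr /= 2  # mirrors FLAGS.learning_rate /= 2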
Example #4
def main(_):
  pp.pprint(FLAGS.__flags)

  if not os.path.exists(FLAGS.checkpoint_dir):
    print(" [*] Creating checkpoint directory...")
    os.makedirs(FLAGS.checkpoint_dir)

  data_loader = TextLoader(os.path.join(FLAGS.data_dir, FLAGS.dataset_name),
                           FLAGS.batch_size, FLAGS.seq_length)
  vocab_size = data_loader.vocab_size
  valid_size = 50
  valid_window = 100

  with tf.variable_scope('model'):
    train_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob,
                          FLAGS.grad_clip, FLAGS.nce_samples)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, train_model.global_step,
                                               data_loader.num_batches, FLAGS.grad_clip,
                                               staircase=True)
  with tf.variable_scope('model', reuse=True):
    simple_model = CharRNN(vocab_size, 1, FLAGS.rnn_size,
                           FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                           1, FLAGS.keep_prob,
                           FLAGS.grad_clip)

  with tf.variable_scope('model', reuse=True):
    valid_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob,
                          FLAGS.grad_clip)

  with tf.Session() as sess:
    tf.global_variables_initializer().run()

    best_val_pp = float('inf')
    best_val_epoch = 0
    valid_loss = 0
    valid_perplexity = 0
    start = time.time()

    if FLAGS.export:
      print("Eval...")
      final_embeddings = train_model.embedding.eval(sess)
      emb_file = os.path.join(FLAGS.data_dir, FLAGS.dataset_name, 'emb.npy')
      print("Embedding shape: {}".format(final_embeddings.shape))
      np.save(emb_file, final_embeddings)

    else: # Train
      current_step = 0
      similarity, valid_examples, _ = compute_similarity(train_model, valid_size, valid_window, 6)

      # save hyper-parameters
      cPickle.dump(FLAGS.__flags, open(FLAGS.log_dir + "/hyperparams.pkl", 'wb'))

      # run it!
      for e in range(FLAGS.num_epochs):
        data_loader.reset_batch_pointer()

        # decay learning rate
        sess.run(tf.assign(train_model.lr, learning_rate))

        # iterate by batch
        for b in range(data_loader.num_batches):
          x, y = data_loader.next_batch()
          res, time_batch = run_epochs(sess, x, y, train_model)
          train_loss = res["loss"][0]
          train_perplexity = np.exp(train_loss)
          iterate = e * data_loader.num_batches + b

          if current_step != 0 and current_step % FLAGS.valid_every == 0:
            valid_loss = 0

            for vb in range(data_loader.num_valid_batches):
              res, valid_time_batch = run_epochs(sess, data_loader.x_valid[vb], data_loader.y_valid[vb], valid_model, False)
              valid_loss += res["loss"][0]

            valid_loss = valid_loss / data_loader.num_valid_batches
            valid_perplexity = np.exp(valid_loss)

            print("### valid_perplexity = {:.2f}, time/batch = {:.2f}".format(valid_perplexity, valid_time_batch))

            log_str = ""

            # Generate sample
            smp1 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"我喜歡做")
            smp2 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"他吃飯時會用")
            smp3 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"人類總要重複同樣的")
            smp4 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"天色暗了,好像快要")

            log_str = log_str + smp1 + "\n"
            log_str = log_str + smp2 + "\n"
            log_str = log_str + smp3 + "\n"
            log_str = log_str + smp4 + "\n"

            # Write a similarity log
            # Note that this is expensive (~20% slowdown if computed every 500 steps)
            sim = similarity.eval()
            for i in range(valid_size):
              valid_word = data_loader.chars[valid_examples[i]]
              top_k = 8 # number of nearest neighbors
              nearest = (-sim[i, :]).argsort()[1:top_k+1]
              log_str = log_str + "Nearest to %s:" % valid_word
              for k in range(top_k):
                close_word = data_loader.chars[nearest[k]]
                log_str = "%s %s," % (log_str, close_word)
              log_str = log_str + "\n"
            print(log_str)
            # Write to log
            text_file = codecs.open(FLAGS.log_dir + "/similarity.txt", "w", "utf-8")
            text_file.write(log_str)
            text_file.close()

          # print log
          print("{}/{} (epoch {}) loss = {:.2f}({:.2f}) perplexity(train/valid) = {:.2f}({:.2f}) time/batch = {:.2f} chars/sec = {:.2f}k"\
              .format(e * data_loader.num_batches + b,
                      FLAGS.num_epochs * data_loader.num_batches,
                      e, train_loss, valid_loss, train_perplexity, valid_perplexity,
                      time_batch, (FLAGS.batch_size * FLAGS.seq_length) / time_batch / 1000))

          current_step = tf.train.global_step(sess, train_model.global_step)

        if valid_perplexity < best_val_pp:
          best_val_pp = valid_perplexity
          best_val_epoch = iterate

          # save best model
          train_model.save(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)
          print("model saved to {}".format(FLAGS.checkpoint_dir))

        # early_stopping
        if iterate - best_val_epoch > FLAGS.early_stopping:
          print('Total time: {}'.format(time.time() - start))
          break
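
In this example the learning rate follows tf.train.exponential_decay with staircase=True (the snippet passes FLAGS.grad_clip as the decay-rate argument). A plain-Python sketch of that schedule with made-up numbers:

import math

def staircase_exponential_decay(base_lr, global_step, decay_steps, decay_rate):
    # decayed_lr = base_lr * decay_rate ** floor(global_step / decay_steps)
    return base_lr * decay_rate ** math.floor(global_step / decay_steps)

for step in (0, 500, 1000, 1500):
    print(step, staircase_exponential_decay(0.002, step, decay_steps=500, decay_rate=0.97))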
Example #5
def main(_):
  pp.pprint(FLAGS.__flags)

  if not os.path.exists(FLAGS.checkpoint_dir):
    print(" [*] Creating checkpoint directory...")
    os.makedirs(FLAGS.checkpoint_dir)

  data_loader = TextLoader(os.path.join(FLAGS.data_dir, FLAGS.dataset_name),
                           FLAGS.batch_size, FLAGS.seq_length)
  vocab_size = data_loader.vocab_size
  valid_size = 50
  valid_window = 100

  with tf.variable_scope('model'):
    train_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob,
                          FLAGS.grad_clip)

  with tf.variable_scope('model', reuse=True):
    simple_model = CharRNN(vocab_size, 1, FLAGS.rnn_size,
                           FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                           1, FLAGS.keep_prob,
                           FLAGS.grad_clip)

  with tf.variable_scope('model', reuse=True):
    valid_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob,
                          FLAGS.grad_clip)

  with tf.Session() as sess:
    tf.global_variables_initializer().run()

    train_model.load(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)

    best_val_pp = float('inf')
    best_val_epoch = 0
    valid_loss = 0
    valid_perplexity = 0
    start = time.time()

    if FLAGS.export:
      print("Eval...")
      final_embeddings = train_model.embedding.eval(sess)
      emb_file = os.path.join(FLAGS.data_dir, FLAGS.dataset_name, 'emb.npy')
      print("Embedding shape: {}".format(final_embeddings.shape))
      np.save(emb_file, final_embeddings)

    else: # Train
      current_step = 0
      similarity, valid_examples, _ = compute_similarity(train_model, valid_size, valid_window, 6)

      # save hyper-parameters
      cPickle.dump(FLAGS.__flags, open(FLAGS.log_dir + "/hyperparams.pkl", 'wb'))

      # run it!
      for e in range(FLAGS.num_epochs):
        data_loader.reset_batch_pointer()

        # decay learning rate
        sess.run(tf.assign(train_model.lr, FLAGS.learning_rate))

        # iterate by batch
        for b in range(data_loader.num_batches):
          x, y = data_loader.next_batch()
          res, time_batch = run_epochs(sess, x, y, train_model)
          train_loss = res["loss"]
          train_perplexity = np.exp(train_loss)
          iterate = e * data_loader.num_batches + b

          # print log
          print("{}/{} (epoch {}) loss = {:.2f}({:.2f}) perplexity(train/valid) = {:.2f}({:.2f}) time/batch = {:.2f} chars/sec = {:.2f}k"\
              .format(e * data_loader.num_batches + b,
                      FLAGS.num_epochs * data_loader.num_batches,
                      e, train_loss, valid_loss, train_perplexity, valid_perplexity,
                      time_batch, (FLAGS.batch_size * FLAGS.seq_length) / time_batch / 1000))

          current_step = tf.train.global_step(sess, train_model.global_step)

        # validate
        valid_loss = 0

        for vb in range(data_loader.num_valid_batches):
          res, valid_time_batch = run_epochs(sess, data_loader.x_valid[vb], data_loader.y_valid[vb], valid_model, False)
          valid_loss += res["loss"]

        valid_loss = valid_loss / data_loader.num_valid_batches
        valid_perplexity = np.exp(valid_loss)

        print("### valid_perplexity = {:.2f}, time/batch = {:.2f}".format(valid_perplexity, valid_time_batch))

        log_str = ""

        # Generate sample
        smp1 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"我喜歡做")
        smp2 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"他吃飯時會用")
        smp3 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"人類總要重複同樣的")
        smp4 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"天色暗了,好像快要")

        log_str = log_str + smp1 + "\n"
        log_str = log_str + smp2 + "\n"
        log_str = log_str + smp3 + "\n"
        log_str = log_str + smp4 + "\n"

        # Write a similarity log
        # Note that this is expensive (~20% slowdown if computed every 500 steps)
        sim = similarity.eval()
        for i in range(valid_size):
          valid_word = data_loader.chars[valid_examples[i]]
          top_k = 8 # number of nearest neighbors
          nearest = (-sim[i, :]).argsort()[1:top_k+1]
          log_str = log_str + "Nearest to %s:" % valid_word
          for k in range(top_k):
            close_word = data_loader.chars[nearest[k]]
            log_str = "%s %s," % (log_str, close_word)
          log_str = log_str + "\n"
        print(log_str)

        # Write to log
        text_file = codecs.open(FLAGS.log_dir + "/similarity.txt", "w", "utf-8")
        text_file.write(log_str)
        text_file.close()

        if valid_perplexity < best_val_pp:
          best_val_pp = valid_perplexity
          best_val_epoch = iterate

          # save best model
          train_model.save(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)
          print("model saved to {}".format(FLAGS.checkpoint_dir))

        # early_stopping
        if iterate - best_val_epoch > FLAGS.early_stopping:
          print('Total time: {}'.format(time.time() - start))
          break
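
The similarity log relies on compute_similarity, which is not shown here; the nearest-neighbour lookup it feeds is cosine similarity over the embedding matrix. A self-contained NumPy sketch with a random stand-in embedding:

import numpy as np

embeddings = np.random.RandomState(0).randn(10, 4)   # stand-in for the learned embedding matrix
normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

valid_ids = [2, 5]                                    # stand-in for valid_examples
sim = normed[valid_ids].dot(normed.T)                 # analogue of similarity.eval()

top_k = 3
for row, vid in zip(sim, valid_ids):
    nearest = (-row).argsort()[1:top_k + 1]           # skip position 0, the word itself
    print("nearest to", vid, ":", nearest.tolist())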
Example #6
    def __init__(self, fw_hyp_path, bw_hyp_path, fw_vocab_path, bw_vocab_path, fw_model_path, bw_model_path,
                 dictionary_path, threshold=math.exp(-6.8)):
        '''
        Load the solver.
        :param fw_hyp_path: forward model hyperparameter path
        :param bw_hyp_path: backward model hyperparameter path
        :param fw_vocab_path: forward model vocab path
        :param bw_vocab_path: backward model vocab path
        :param fw_model_path: forward model checkpoint directory
        :param bw_model_path: backward model checkpoint directory
        :param dictionary_path: user dictionary path (also loaded into jieba)
        :param threshold: probability threshold; combined forward/backward scores below log(threshold) flag a position
        '''
        jieba.load_userdict(dictionary_path)

        self.threshold = np.log(threshold)

        # load configs
        with open(fw_hyp_path, 'rb') as f:
            fw_hyp_config = pickle.load(f)
        with open(bw_hyp_path, 'rb') as f:
            bw_hyp_config = pickle.load(f)

        # load vocabularies
        with open(fw_vocab_path, 'rb') as f:
            self.fw_vocab_i2c = pickle.load(f)
            self.fw_vocab_size = len(self.fw_vocab_i2c)
            self.fw_vocab_c2i = dict(zip(self.fw_vocab_i2c, range(self.fw_vocab_size)))
        with open(bw_vocab_path, 'rb') as f:
            self.bw_vocab_i2c = pickle.load(f)
            self.bw_vocab_size = len(self.bw_vocab_i2c)
            self.bw_vocab_c2i = dict(zip(self.bw_vocab_i2c, range(self.bw_vocab_size)))

        # load fwmodel
        g1 = tf.Graph()
        self.fw_sess = tf.Session(graph=g1)
        with self.fw_sess.as_default():
            with g1.as_default():
                with tf.variable_scope(fw_hyp_config['dataset_name']):
                    self.fw_model = CharRNN(vocab_size=self.fw_vocab_size,
                                            batch_size=1,
                                            rnn_size=fw_hyp_config['rnn_size'],
                                            layer_depth=fw_hyp_config['layer_depth'],
                                            num_units=fw_hyp_config['num_units'],
                                            seq_length=1,
                                            keep_prob=fw_hyp_config['keep_prob'],
                                            grad_clip=fw_hyp_config['grad_clip'],
                                            rnn_type=fw_hyp_config['rnn_type'])
                ckpt = tf.train.get_checkpoint_state(fw_model_path +
                                                     '/' +
                                                     fw_hyp_config['dataset_name'])
                tf.train.Saver().restore(self.fw_sess, ckpt.model_checkpoint_path)
        # print("fwmodel done!")

        # load bwmodel
        g2 = tf.Graph()
        self.bw_sess = tf.Session(graph=g2)
        with self.bw_sess.as_default():
            with g2.as_default():
                with tf.variable_scope(bw_hyp_config['dataset_name']):
                    self.bw_model = CharRNN(vocab_size=self.bw_vocab_size,
                                            batch_size=1,
                                            rnn_size=bw_hyp_config['rnn_size'],
                                            layer_depth=bw_hyp_config['layer_depth'],
                                            num_units=bw_hyp_config['num_units'],
                                            seq_length=1,
                                            keep_prob=bw_hyp_config['keep_prob'],
                                            grad_clip=bw_hyp_config['grad_clip'],
                                            rnn_type=bw_hyp_config['rnn_type'])
                ckpt = tf.train.get_checkpoint_state(bw_model_path +
                                                     '/' +
                                                     bw_hyp_config['dataset_name'])
                tf.train.Saver().restore(self.bw_sess, ckpt.model_checkpoint_path)
        # print("bwmodel done!")

        # load dictionary
        with open(dictionary_path, "r", encoding="utf-8") as f:
            self.dictionary = set()
            self.word_max_length = 0
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    continue
                self.dictionary.add(line)
                if self.word_max_length < len(line):
                    self.word_max_length = len(line)
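
The constructor keeps the forward and backward models in two separate graphs so that their identically named variables never collide. A minimal TF 1.x sketch of that two-graph, two-session pattern:

import tensorflow as tf

g1, g2 = tf.Graph(), tf.Graph()
with g1.as_default():
    tf.constant(1, name="a")
with g2.as_default():
    tf.constant(2, name="a")  # same name, no clash: it lives in a different graph

sess1, sess2 = tf.Session(graph=g1), tf.Session(graph=g2)
print(sess1.run(g1.get_tensor_by_name("a:0")),
      sess2.run(g2.get_tensor_by_name("a:0")))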
Example #7
class LanguageCorrector():
    '''
    Natural Language Correction Model
    '''

    def __init__(self, fw_hyp_path, bw_hyp_path, fw_vocab_path, bw_vocab_path, fw_model_path, bw_model_path,
                 dictionary_path, threshold=math.exp(-6.8)):
        '''
        Load the solver.
        :param fw_hyp_path: forward model hyperparameter path
        :param bw_hyp_path: backward model hyperparameter path
        :param fw_vocab_path: forward model vocab path
        :param bw_vocab_path: backward model vocab path
        :param fw_model_path: forward model checkpoint directory
        :param bw_model_path: backward model checkpoint directory
        :param dictionary_path: user dictionary path (also loaded into jieba)
        :param threshold: probability threshold; combined forward/backward scores below log(threshold) flag a position
        '''
        jieba.load_userdict(dictionary_path)

        self.threshold = np.log(threshold)

        # load configs
        with open(fw_hyp_path, 'rb') as f:
            fw_hyp_config = pickle.load(f)
        with open(bw_hyp_path, 'rb') as f:
            bw_hyp_config = pickle.load(f)

        # load vocabularies
        with open(fw_vocab_path, 'rb') as f:
            self.fw_vocab_i2c = pickle.load(f)
            self.fw_vocab_size = len(self.fw_vocab_i2c)
            self.fw_vocab_c2i = dict(zip(self.fw_vocab_i2c, range(self.fw_vocab_size)))
        with open(bw_vocab_path, 'rb') as f:
            self.bw_vocab_i2c = pickle.load(f)
            self.bw_vocab_size = len(self.bw_vocab_i2c)
            self.bw_vocab_c2i = dict(zip(self.bw_vocab_i2c, range(self.bw_vocab_size)))

        # load fwmodel
        g1 = tf.Graph()
        self.fw_sess = tf.Session(graph=g1)
        with self.fw_sess.as_default():
            with g1.as_default():
                with tf.variable_scope(fw_hyp_config['dataset_name']):
                    self.fw_model = CharRNN(vocab_size=self.fw_vocab_size,
                                            batch_size=1,
                                            rnn_size=fw_hyp_config['rnn_size'],
                                            layer_depth=fw_hyp_config['layer_depth'],
                                            num_units=fw_hyp_config['num_units'],
                                            seq_length=1,
                                            keep_prob=fw_hyp_config['keep_prob'],
                                            grad_clip=fw_hyp_config['grad_clip'],
                                            rnn_type=fw_hyp_config['rnn_type'])
                ckpt = tf.train.get_checkpoint_state(fw_model_path +
                                                     '/' +
                                                     fw_hyp_config['dataset_name'])
                tf.train.Saver().restore(self.fw_sess, ckpt.model_checkpoint_path)
        # print("fwmodel done!")

        # load bwmodel
        g2 = tf.Graph()
        self.bw_sess = tf.Session(graph=g2)
        with self.bw_sess.as_default():
            with g2.as_default():
                with tf.variable_scope(bw_hyp_config['dataset_name']):
                    self.bw_model = CharRNN(vocab_size=self.bw_vocab_size,
                                            batch_size=1,
                                            rnn_size=bw_hyp_config['rnn_size'],
                                            layer_depth=bw_hyp_config['layer_depth'],
                                            num_units=bw_hyp_config['num_units'],
                                            seq_length=1,
                                            keep_prob=bw_hyp_config['keep_prob'],
                                            grad_clip=bw_hyp_config['grad_clip'],
                                            rnn_type=bw_hyp_config['rnn_type'])
                ckpt = tf.train.get_checkpoint_state(bw_model_path +
                                                     '/' +
                                                     bw_hyp_config['dataset_name'])
                tf.train.Saver().restore(self.bw_sess, ckpt.model_checkpoint_path)
        # print("bwmodel done!")

        # load dictionary
        with open(dictionary_path, "r", encoding="utf-8") as f:
            self.dictionary = set()
            self.word_max_length = 0
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    continue
                self.dictionary.add(line)
                if self.word_max_length < len(line):
                    self.word_max_length = len(line)

    def correctify(self, sentence):
        '''
        Correct a sentence.
        :param sentence: input sentence
        :return: list of correction dicts with sourceValue, correctValue, startOffset and endOffset keys
        '''
        chars = [GO] + list(sentence) + [EOS]
        bw_chars = chars[::-1]
        sz = len(chars)
        fw_ints = [self.fw_vocab_c2i.get(c, UNK_ID) for c in chars]
        bw_ints = [self.bw_vocab_c2i.get(c, UNK_ID) for c in bw_chars]

        fw_losses = {}  # loss after consuming the prefix up to each character
        bw_losses = {}
        fw_probs = {}
        bw_probs = {}

        # flag suspicious character positions
        bads_or_not = []
        bad_pos = set()
        for i in range(sz):
            bads_or_not.append(False)
        # fw side
        with self.fw_sess.as_default():
            with self.fw_sess.graph.as_default():
                for i in range(1, sz - 1):
                    fw_substr_ints = fw_ints[:i + 1]
                    fw_losses[i], fw_probs[i] = self.fw_model.get_loss(self.fw_sess, fw_substr_ints)

        # bw side
        with self.bw_sess.as_default():
            with self.bw_sess.graph.as_default():
                for i in range(1, sz - 1):
                    bw_substr_ints = bw_ints[:i + 1]
                    bw_losses[i], bw_probs[i] = self.bw_model.get_loss(self.bw_sess, bw_substr_ints)

        # first pass: score each position with the combined forward and backward losses
        results = []
        for i in range(1, sz - 1):
            # print(fw_losses[i], bw_losses[sz - 1 - i])
            t_loss = fw_losses[i] + bw_losses[sz - 1 - i]
            # print(t_loss)
            # print(chars[:i + 1])
            results.append([i, t_loss])
        results = list(sorted(results, key=lambda x: x[1]))
        for i in range((len(results) + 1) // 2):
            score = results[i][1]
            if score < self.threshold:
                pos = results[i][0]
                bads_or_not[pos] = True
                bad_pos.add(pos)

        # second pass: un-flag positions covered by a word in the dictionary
        for p in range(sz):
            if not bads_or_not[p]:
                continue
            for word_len in range(2, self.word_max_length):
                left_p = max(1, p - word_len + 1)
                right_p = min(sz - 2 - word_len + 1, p)
                for left in range(left_p, right_p + 1):
                    subword = sentence[left - 1:left - 1 + word_len]
                    if subword in self.dictionary:
                        # print(subword)
                        bads_or_not[p] = False
                        bad_pos.remove(p)
                    if not bads_or_not[p]:
                        break
                if not bads_or_not[p]:
                    break
        # print(bad_pos)

        # find candidates
        for p in bad_pos:
            if not 1 <= p <= sz - 2:
                continue
            best_ch = ""
            best_score = np.NINF

            left_ints = fw_ints[:p]
            left_sz = len(left_ints)
            with self.fw_sess.as_default():
                with self.fw_sess.graph.as_default():
                    left_loss, left_probs = self.fw_model.get_loss(self.fw_sess, left_ints)

            right_ints = bw_ints[:sz - 1 - p]
            right_sz = len(right_ints)
            with self.bw_sess.as_default():
                with self.bw_sess.graph.as_default():
                    right_loss, right_probs = self.bw_model.get_loss(self.bw_sess, right_ints)

            for ic, ch in enumerate(self.fw_vocab_i2c):
                if ch in START_VOCAB or ch in punctuation or ch in ALL_PUNCTUATION:
                    continue

                loss = (left_loss * left_sz + math.log(left_probs[ic])) / (left_sz + 1) + \
                       (right_loss * right_sz + math.log(right_probs[self.bw_vocab_c2i[ch]])) / (right_sz + 1)

                if loss > best_score + 1e-6:
                    best_score = loss
                    best_ch = ch
            chars[p] = best_ch
        chars = "".join(chars[1:-1])
        assert len(chars) == len(sentence)
        segs = jieba.cut(chars)

        # get requested format
        ans_array = []
        a_p = 0
        for seg in segs:
            s_s = a_p
            t_t = s_s + len(seg)
            if chars[s_s:t_t] != sentence[s_s:t_t]:
                ans_array.append({
                    "sourceValue": sentence[s_s:t_t],
                    "correctValue": chars[s_s:t_t],
                    "startOffset": s_s,
                    "endOffset": t_t
                })
            a_p = t_t
        return ans_array
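
The candidate score in correctify() extends each side's running mean log-likelihood by one more character and then sums the two sides. A small sketch of the per-side update with hypothetical numbers:

import math

def extend_mean_log_likelihood(mean_ll, length, next_char_prob):
    # append one character to a prefix whose mean log-likelihood over `length` chars is mean_ll
    return (mean_ll * length + math.log(next_char_prob)) / (length + 1)

# hypothetical: a 3-character prefix with mean log-likelihood -2.0,
# and a candidate character the model assigns probability 0.05
print(extend_mean_log_likelihood(-2.0, 3, 0.05))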