Example #1
def eval(flags):
    train_data_df = load_data_from_csv(flags.train_csv_file)
    columns = train_data_df.columns.values.tolist()
    hparams = load_hparams(
        flags.checkpoint_dir, {
            "mode": 'eval',
            'checkpoint_dir': flags.checkpoint_dir + "/best_dev",
            "batch_size": flags.batch_size
        })

    save_hparams(flags.checkpoint_dir, hparams)
    for column in columns[2:]:
        dataset = DataSet(flags.data_file, flags.train_csv_file, column,
                          flags.batch_size, flags.vocab_file,
                          flags.max_sent_in_doc, flags.max_word_in_sent)
        with tf.Session(config=get_config_proto(
                log_device_placement=False)) as sess:
            model = HAN(hparams)
            try:
                model.restore_model(sess)  # restore the best saved model
            except Exception as e:
                print("unable to restore model with exception", e)
                exit(1)

            checkpoint_loss = 0.0
            for i, (x, y) in enumerate(dataset.get_next()):
                batch_loss, accuracy = model.eval_one_batch(sess, x, y)
                checkpoint_loss += batch_loss * flags.batch_size
                print("# batch {0}/{1}, loss {2},accuracy{3}".format(
                    i, flags.batch_size, checkpoint_loss, accuracy))
            return batch_loss, accuracy
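
A minimal invocation sketch for the eval function above; the argparse.Namespace stand-in and all flag values are illustrative assumptions (the original project presumably defines these flags elsewhere, e.g. with tf.app.flags):

from argparse import Namespace

# Hypothetical flag values; names mirror the attributes eval() reads above.
flags = Namespace(train_csv_file='train.csv',
                  checkpoint_dir='checkpoint',
                  data_file='eval_data.txt',
                  vocab_file='vocab.txt',
                  batch_size=32,
                  max_sent_in_doc=30,
                  max_word_in_sent=30)
eval(flags)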
Example #2
File: predict.py Project: mindis/KBQA
class RelationMatcher:
    def __init__(self, dir_path):
        checkpoint_dir = os.path.abspath(os.path.join(dir_path, "checkpoints"))
        save_path = os.path.join(checkpoint_dir, "model")
        config_path = os.path.join(dir_path, 'config.json')
        parameters = json.load(open(config_path))

        parameters['reload'] = True
        parameters['load_path'] = save_path
        with tf.Graph().as_default():
            self.model = RelationMatcherModel(parameters)
        self.dataset = DataSet(parameters)

    def get_match_score(self, pattern, relation):
        data = self.dataset.create_model_input([pattern], [relation])
        scores = self.model.predict(data['word_ids'], data['sentence_lengths'],
                                    data['char_ids'], data['word_lengths'],
                                    data['relation_ids'],
                                    data['pattern_positions'],
                                    data['relation_positions'])
        return scores[0]

    def get_batch_match_score(self, patterns, relations):
        # TODO: if number of relation is big, compute score in batches
        data = self.dataset.create_model_input(patterns, relations)
        scores, pattern_repr, relation_repr = self.model.predict(
            data['word_ids'],
            data['sentence_lengths'],
            data['char_ids'],
            data['word_lengths'],
            data['relation_ids'],
            data['pattern_positions'],
            data['relation_positions'],
            include_repr=True)
        return scores, pattern_repr, relation_repr
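
A usage sketch for RelationMatcher; the model directory, pattern, and relation strings below are made-up placeholders, not values from the project:

# Hypothetical model directory and inputs.
matcher = RelationMatcher('runs/relation_matcher')
score = matcher.get_match_score('who wrote <entity>', 'book.written_work.author')
print(score)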
Example #3
def __init__(self, dir_path):
    checkpoint_dir = os.path.abspath(os.path.join(dir_path, "checkpoints"))
    save_path = os.path.join(checkpoint_dir, "model")
    config_path = os.path.join(dir_path, 'config.json')
    parameters = json.load(open(config_path))
    parameters['reload'] = True
    parameters['load_path'] = save_path
    with tf.Graph().as_default():
        self.model = BetaRankerModel(parameters)
    self.dataset = DataSet(parameters)
Example #4
class BetaRanker:
    def __init__(self, dir_path):
        checkpoint_dir = os.path.abspath(os.path.join(dir_path, "checkpoints"))
        save_path = os.path.join(checkpoint_dir, "model")
        config_path = os.path.join(dir_path, 'config.json')
        parameters = json.load(open(config_path))
        parameters['reload'] = True
        parameters['load_path'] = save_path
        with tf.Graph().as_default():
            self.model = BetaRankerModel(parameters)
        self.dataset = DataSet(parameters)

    def rank_queries(self, queries):
        data = self.dataset.create_model_input(queries)
        scores, _ = self.model.predict(
            data['pattern_word_ids'],
            data['sentence_lengths'],
            None,  # TODO: support pattern char-based feature
            None,
            data['relation_ids'],
            data['relation_lengths'],
            data['mention_char_ids'],
            data['topic_char_ids'],
            data['mention_lengths'],
            data['topic_lengths'],
            data['question_word_ids'],
            data['question_lengths'],
            data['type_ids'],
            data['type_lengths'],
            data['answer_type_ids'],
            data['answer_type_weights'],
            data['qword_ids'],
            data['extras'])
        return dict(zip(data['hash'], scores))
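
Note the return shape of rank_queries: it zips data['hash'] with the model scores, so callers receive a dict keyed by each candidate's hash rather than a bare score list.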
Example #5
def translate(dir_path):
    checkpoint_dir = os.path.abspath(os.path.join(dir_path, "checkpoints"))
    save_path = os.path.join(checkpoint_dir, "model")
    config_path = os.path.join(dir_path, 'config.json')
    params = json.load(open(config_path))
    params['reload'] = True
    params['load_path'] = save_path
    with tf.Graph().as_default():
        model = Model(params)
    dataset = DataSet(params['dictionaries'], params['max_len'],
                      params['source_vocab_size'], params['target_vocab_size'])

    sentences = [
        'Access to the Internet is itself a fundamental right .',
        'These schools then subsidised free education for the non - working poor .',
        'What will they do ? CSSD lacks knowledge of both Voldemort and candy bars in Prague .'
    ]
    xs, x_mask = dataset.create_model_input(sentences)

    for x, mask in zip(xs, x_mask):
        print(x, mask)
        print(dataset.to_readable_source(x))
        ys, scores = model.predict(x[None, :], mask[None, :], 10,
                                   params['max_len'])
        for i, y in enumerate(ys):
            print(dataset.to_readable(y), scores[i])
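
A minimal call sketch; the directory name is a placeholder for a trained model directory laid out as translate expects (a checkpoints/ subdirectory and a config.json):

# Hypothetical model directory.
translate('runs/nmt_model')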
Example #6
File: predict.py Project: mindis/KBQA
    def __init__(self, dir_path):

        checkpoint_dir = os.path.abspath(os.path.join(dir_path, "checkpoints"))
        save_path = os.path.join(checkpoint_dir, "model")
        config_path = os.path.join(dir_path, 'config.json')
        parameters = json.load(open(config_path))
        with tf.Graph().as_default():
            self.model = DeepCRF(
                parameters['max_sentence_len'], parameters['max_word_len'],
                parameters['char_dim'], parameters['char_rnn_dim'],
                parameters['char_bidirect'] == 1, parameters['word_dim'],
                parameters['word_rnn_dim'], parameters['word_bidirect'] == 1,
                parameters['cap_dim'], parameters['pos_dim'], save_path,
                parameters['num_word'], parameters['num_char'],
                parameters['num_cap'], parameters['num_pos'],
                parameters['num_tag'])
        self.tag_scheme = parameters['tag_scheme']
        self.use_part_of_speech = 'pos_dim' in parameters and parameters[
            'pos_dim'] > 0
        self.dataset = DataSet(parameters)
        self.nlp_parser = NLPParser()
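        # Undo PTB-style token escaping: map -LSB-, -RCB-, etc. back to literal brackets.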
        self.fix = lambda x: x.replace('-LSB-', '[').replace(
            '-RSB-', ']').replace('-LCB-', '{').replace('-RCB-', '}').replace(
                '-LRB-', '(').replace('-RRB-', ')')
Example #7
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
    save_path = os.path.join(checkpoint_dir, "model")
    dev_res_path = os.path.join(out_dir, 'dev.res')
    log_path = os.path.join(out_dir, 'train.log')
    config_path = os.path.join(out_dir, 'config.json')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if config['reload']:
        config['load_path'] = save_path
    else:
        config['load_path'] = None

    dataset = DataSet(config)
    config["pattern_config"]['num_word'] = dataset.num_word
    config["pattern_config"]['num_char'] = dataset.num_char
    config['relation_config']['num_word'] = dataset.num_relation
    config['relation_config']['num_char'] = dataset.num_char
    if 'type_config' in config:
        config['type_config']['num_word'] = dataset.num_type
        config['question_config']['num_word'] = dataset.num_word
    if 'topic_config' in config:
        config['topic_config']['num_word'] = dataset.num_char  # this is char-based
    if 'answer_type_config' in config:
        config['answer_type_config']['num_answer_type'] = dataset.num_answer_type
        config['answer_type_config']['num_qword'] = dataset.num_qword
        config['answer_type_config']['num_word'] = dataset.num_word  # FOR ADD ENCODER

    model = BetaRankerModel(config)
Example #8
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
    save_path = os.path.join(checkpoint_dir, "model")
    dev_res_path = os.path.join(out_dir, 'dev.res')
    log_path = os.path.join(out_dir, 'train.log')
    config_path = os.path.join(out_dir, 'config.json')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if config['reload']:
        config['load_path'] = save_path
    else:
        config['load_path'] = None

    dataset = DataSet(config['dictionaries'],
                      config['max_len'],
                      n_words_source=config['source_vocab_size'],
                      n_word_target=config['target_vocab_size'])
    #config['source_vocab_size'] = dataset.num_source_word
    #config['target_vocab_size'] = dataset.num_target_word

    fout_log = open(log_path, 'a')
    with open(config_path, 'w') as fout:
        print(json.dumps(config), file=fout)

    model = Model(config)

    for epoch_index in range(config['num_epoch']):
        tic = time()
        lno = 0
        total_loss = 0.
        for data in dataset.train_batch_iterator(config['dataset'],
Example #9
flags.DEFINE_string("model_name", "gec",
                    "File name used for model checkpoints")

flags.DEFINE_string("source_file", "data/token_lang-8/lang8-train.en",
                    "source train file path")
flags.DEFINE_string("target_file", "data/token_lang-8/lang8-train.gec",
                    "target train file path")
flags.DEFINE_string("source_valid", "data/token_lang-8/lang8-valid.en",
                    "source valid file path")
flags.DEFINE_string("target_valid", "data/token_lang-8/lang8-valid.gec",
                    "target valid file path")

dataSet = DataSet(FLAGS.embedding_size,
                  FLAGS.source_file,
                  FLAGS.target_file,
                  FLAGS.source_valid,
                  FLAGS.target_valid,
                  FLAGS.batch_size,
                  is_first=False)

# Generate the training and validation data
dataSet.gen_train_valid()
vocab_size = len(dataSet.idx_to_word)

print(vocab_size)

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9,
                            allow_growth=True)
config = tf.ConfigProto(log_device_placement=False,
                        allow_soft_placement=True,
                        gpu_options=gpu_options)
flags.DEFINE_integer("epochs", 10, "Maximum # of training epochs")
flags.DEFINE_string("word_embedding_path",
                    "word2vec/blstm_atten/word_embedding.npy",
                    "word embedding numpy store path")
flags.DEFINE_string("data_path", "data/tokens.txt", "raw data path")
flags.DEFINE_integer("steps_per_checkpoint", 100,
                     "Save model checkpoint every this iteration")
flags.DEFINE_string("model_dir", "model/blstm_atten/",
                    "Path to save model checkpoints")
flags.DEFINE_string("model_name", "atec.ckpt",
                    "File name used for model checkpoints")

flags.DEFINE_float("ratio", 0.1, "eval data ratio")

dataSet = DataSet(filename=FLAGS.data_path,
                  embedding_size=FLAGS.embedding_size,
                  model="blstm_atten")

# Generate the training and evaluation sets
train_data, eval_data = dataSet.gen_train_eval_data()
print("train number: {}".format(len(train_data)))

word_embedding = np.load(FLAGS.word_embedding_path)

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9,
                            allow_growth=True)
config = tf.ConfigProto(log_device_placement=False,
                        allow_soft_placement=True,
                        gpu_options=gpu_options)

with tf.Session(config=config) as sess:
Example #11
class CoFiltering:
    def __init__(self):
        self.data = DataSet()

    def pearson_sim(self, user1, user2):
        """
        Calculate the Pearson correlation coefficient of user1 and user2,
        using the computational form
        r = (sum(xy) - sum(x)*sum(y)/n) / sqrt((sum(x^2) - sum(x)^2/n) * (sum(y^2) - sum(y)^2/n)).
        """
        user1_array = self.data.matrix[user1 - 1]
        user2_array = self.data.matrix[user2 - 1]
        length = len(user1_array)
        sum1 = sum(user1_array)
        sum2 = sum(user2_array)
        sum_mul = self.multi(user1_array, user2_array)
        sum_x2 = sum([i**2 for i in user1_array])
        sum_y2 = sum([j**2 for j in user2_array])
        num = sum_mul - (float(sum1) * float(sum2) / length)
        den = sqrt((sum_x2 - float(sum1**2) / length) *
                   (sum_y2 - float(sum2**2) / length))
        if den == 0:
            return 0.0  # at least one user has constant ratings; correlation undefined
        return num / den

    @staticmethod
    def multi(x, y):
        """
        Compute the dot product of two 1D arrays.
        The two arrays must have the same length.
        :param x: one array.
        :param y: another array.
        :return: the dot product.
        """
        result = 0.
        for i in range(len(x)):
            result += x[i] * y[i]
        return result

    def most_similar(self, user1, top_n=5):
        """
        Find the top_n most similar users.
        :param user1: a user id (not an array), e.g. 23.
        :param top_n: number of similar users to return.
        :return: a list like [(most_similar_user_1, score), ..., (most_similar_user_topN, score)].
        """
        result_collect = {}
        for user2 in self.data.users:
            if user2 == user1:
                continue
            try:
                result = self.pearson_sim(user1, user2)
                result_collect[user2] = result
            except IndexError:
                pass

        results_sorted = sorted(result_collect.items(),
                                key=lambda item: item[1],
                                reverse=True)[:top_n]
        print('Most similar users: {}'.format(' | '.join(
            [str(x[0]) for x in results_sorted])))
        return results_sorted

    def predict(self, user, top_n=5, recommend_num=5):
        if user not in self.data.users:
            raise ValueError(
                'Cannot find user "{}", please check.'.format(user))

        results = self.most_similar(user, top_n)
        recommend = []
        for user_id, val in results:
            diff_list = list(self.data.matrix[user_id] -
                             self.data.matrix[user])
            temp = filter(lambda item: item[1] > 0, enumerate(diff_list))
            recommend.extend(temp)
        recommend = sorted(recommend, key=lambda item: item[1], reverse=True)

        movie_list = []
        while True:
            for i in reversed(range(1, 6)):
                temp_list = filter(lambda x: x[1] == i, recommend)
                temp_list = sorted(temp_list,
                                   key=lambda x: self.data.move_pop_rank[x[0]],
                                   reverse=True)
                for x, y in temp_list:
                    if x + 1 not in movie_list:
                        movie_list.append(x + 1)
                if len(movie_list) >= recommend_num:
                    break
            break
        movie_list = [
            self.data.id2movie(x) for x in movie_list[:recommend_num]
        ]
        print('Recommend movies: {}'.format(' | '.join(movie_list)))
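
A usage sketch for CoFiltering, with a made-up user id (the docstrings above use ids like 23):

cf = CoFiltering()
cf.most_similar(23, top_n=5)              # print and return the five most similar users
cf.predict(23, top_n=5, recommend_num=5)  # print five recommended movie titles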
Example #12
def train(flags):
    train_data_df = load_data_from_csv(flags.train_csv_file)
    columns = train_data_df.columns.values.tolist()
    hparams = create_hparams(flags)
    hparams.add_hparam("vocab_size", 10000)
    save_hparams(flags.checkpoint_dir, hparams)
    for column in columns[2:]:
        dataset = DataSet(flags.train_data_file, flags.train_csv_file, column,
                          flags.batch_size, flags.vocab_file,
                          flags.max_sent_in_doc, flags.max_word_in_sent)
        eval_dataset = DataSet(flags.eval_data_file, flags.eval_csv_file,
                               column, flags.batch_size, flags.vocab_file,
                               flags.max_sent_in_doc, flags.max_word_in_sent)
        train_graph = tf.Graph()
        eval_graph = tf.Graph()

        with train_graph.as_default():
            train_model = HAN(hparams)
            initializer = tf.global_variables_initializer()

        with eval_graph.as_default():
            eval_hparams = load_hparams(
                flags.checkpoint_dir, {
                    "mode": 'eval',
                    'checkpoint_dir': flags.checkpoint_dir + "/best_dev"
                })
            eval_model = HAN(eval_hparams)

        train_sess = tf.Session(
            graph=train_graph,
            config=get_config_proto(log_device_placement=False))
        try:
            train_model.restore_model(train_sess)
        except Exception:
            print(
                "unable to restore model, initializing model with fresh params")
            train_model.init_model(train_sess, initializer=initializer)
        print("#{0} model starts to train with learning rate {1}, {2}".format(
            column, flags.learning_rate, time.ctime()))

        global_step = train_sess.run(train_model.global_step)
        eval_ppls = []
        best_eval = float("inf")
        for epoch in range(flags.num_train_epoch):
            checkpoint_loss = 0.0

            for i, (x, y) in enumerate(dataset.get_next()):
                batch_loss, accuracy, summary, global_step = train_model.train_one_batch(
                    train_sess, x, y)
                checkpoint_loss += batch_loss * flags.batch_size
                if global_step == 0:
                    continue

                if global_step % flags.steps_per_stats == 0:
                    summary = tf.Summary()
                    summary.value.add(tag='accuracy', simple_value=accuracy)
                    train_model.summary_writer.add_summary(
                        summary, global_step=global_step)

                    print("# Epoch %d  global step %d batch %d/%d  "
                          "batch loss %.5f accuracy %.2f " %
                          (epoch + 1, global_step, i + 1, flags.batch_size,
                           batch_loss, accuracy))

                if global_step % flags.steps_per_eval == 0:
                    print("# global step {0}, eval model at {1}".format(
                        global_step, time.ctime()))
                    checkpoint_path = train_model.save_model(train_sess)
                    with tf.Session(
                            graph=eval_graph,
                            config=get_config_proto(
                                log_device_placement=False)) as eval_sess:
                        eval_model.saver.restore(eval_sess, checkpoint_path)
                        _, eval_ppl = train_eval(eval_model, eval_sess,
                                                 eval_dataset)
                        if eval_ppl < best_eval:
                            eval_model.save_model(eval_sess)
                            best_eval = eval_ppl
                    eval_ppls.append(eval_ppl)
                    if early_stop(eval_ppls):
                        print("# No loss decrease, early stop")
                        print("# Best perplexity {0}".format(best_eval))
                        exit(0)

            print("# Finsh epoch {1}, global step {0}".format(
                global_step, epoch + 1))
Example #13
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
    save_path = os.path.join(checkpoint_dir, "model")
    dev_res_path = os.path.join(out_dir, 'dev.res')
    log_path = os.path.join(out_dir, 'train.log')
    config_path = os.path.join(out_dir, 'config.json')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if config['reload']:
        config['load_path'] = save_path
    else:
        config['load_path'] = None

    dataset = DataSet(config)
    config["question_config"]['num_word'] = dataset.num_word
    config["question_config"]['num_char'] = dataset.num_char
    config['relation_config']['num_word'] = dataset.num_relation
    config['relation_config']['num_char'] = dataset.num_char

    model = RelationMatcherModel(config)

    fout_log = open(log_path, 'a')
    with open(config_path, 'w') as fout:
        print(json.dumps(config), file=fout)

    best_p_at_1 = 0
    if "fn_dev" in config:
        p_at_1, average_rank, num_avg_candidates, eval_info = \
            evaluate(dataset, model, config['fn_dev'], dev_res_path)
Example #14
    parameters['cap_dim'] = FLAGS.cap_dim
    parameters['pos_dim'] = FLAGS.pos_dim
    parameters['dropout_keep_prob'] = FLAGS.dropout_keep_prob
    if FLAGS.reload == 1:
        parameters['load_path'] = save_path
    else:
        parameters['load_path'] = None
    parameters['tag_scheme'] = FLAGS.tag_scheme
    parameters['fn_word'] = os.path.abspath(
        os.path.join(os.path.curdir, FLAGS.fn_word))
    parameters['fn_char'] = os.path.abspath(
        os.path.join(os.path.curdir, FLAGS.fn_char))
    if FLAGS.pos_dim:
        parameters['fn_pos'] = os.path.abspath(
            os.path.join(os.path.curdir, FLAGS.fn_pos))
    dataset = DataSet(parameters)
    parameters['num_word'] = dataset.num_word
    parameters['num_char'] = dataset.num_char
    parameters['num_cap'] = dataset.num_cap
    parameters['num_tag'] = dataset.num_tag
    parameters['num_pos'] = dataset.num_pos

    model = DeepCRF(FLAGS.max_sentence_len, FLAGS.max_word_len, FLAGS.char_dim,
                    FLAGS.char_rnn_dim, FLAGS.char_bidirect == 1,
                    FLAGS.word_dim, FLAGS.word_rnn_dim,
                    FLAGS.word_bidirect == 1, FLAGS.cap_dim, FLAGS.pos_dim,
                    save_path if FLAGS.reload else None, dataset.num_word,
                    dataset.num_char, dataset.num_cap, dataset.num_pos,
                    dataset.num_tag)
    fout_log = open(log_path, 'a')
Example #15
def __init__(self):
    self.data = DataSet()
Example #16
File: evaluete.py Project: mindis/KBQA
    res_info += "Number of test cases: {}\nAverage rank: {}\nAverage number of candidates: {}"\
        .format(count, average_rank, average_candidate_count)
    if res_file:
        print(res_info, file=res_file)
    print(res_info)
    return p_at_k[0], average_rank, average_candidate_count, res_info


if __name__ == '__main__':
    from data_helper import DataSet

    from model import RelationMatcherModel
    dir_path = sys.argv[1]  # model dir path
    fn_dev = sys.argv[2]  # test file path
    if len(sys.argv) == 3:
        res_name = 'test.res'
    else:
        res_name = sys.argv[3]
    dir_path = os.path.abspath(dir_path)
    checkpoint_dir = os.path.join(dir_path, "checkpoints")
    save_path = os.path.join(checkpoint_dir, "model")

    config_path = os.path.join(dir_path, 'config.json')
    parameters = json.load(open(config_path))
    parameters['load_path'] = save_path

    dataset = DataSet(parameters)
    model = RelationMatcherModel(parameters)
    fn_res = os.path.join(dir_path, res_name)
    evaluate(dataset, model, fn_dev, fn_res)
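
Judging from the sys.argv handling above, the script is invoked as "python evaluete.py <model_dir> <test_file> [res_name]", with res_name defaulting to test.res when the third argument is omitted.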
Example #17
File: predict.py Project: mindis/KBQA
class EntityMentionTagger(object):
    def __init__(self, dir_path):

        checkpoint_dir = os.path.abspath(os.path.join(dir_path, "checkpoints"))
        save_path = os.path.join(checkpoint_dir, "model")
        config_path = os.path.join(dir_path, 'config.json')
        parameters = json.load(open(config_path))
        with tf.Graph().as_default():
            self.model = DeepCRF(
                parameters['max_sentence_len'], parameters['max_word_len'],
                parameters['char_dim'], parameters['char_rnn_dim'],
                parameters['char_bidirect'] == 1, parameters['word_dim'],
                parameters['word_rnn_dim'], parameters['word_bidirect'] == 1,
                parameters['cap_dim'], parameters['pos_dim'], save_path,
                parameters['num_word'], parameters['num_char'],
                parameters['num_cap'], parameters['num_pos'],
                parameters['num_tag'])
        self.tag_scheme = parameters['tag_scheme']
        self.use_part_of_speech = 'pos_dim' in parameters and parameters[
            'pos_dim'] > 0
        self.dataset = DataSet(parameters)
        self.nlp_parser = NLPParser()
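        # Undo PTB-style token escaping: map -LSB-, -RCB-, etc. back to literal brackets.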
        self.fix = lambda x: x.replace('-LSB-', '[').replace(
            '-RSB-', ']').replace('-LCB-', '{').replace('-RCB-', '}').replace(
                '-LRB-', '(').replace('-RRB-', ')')

    def get_pos_tag(self, sentence):
        sentence = naive_split(sentence)
        tokens, poss = self.nlp_parser.tag_pos(' '.join(sentence))
        tokens = [self.fix(t) for t in tokens]
        return tokens, poss

    def tag(self, sentence):
        if self.use_part_of_speech:
            sentence, poss = self.get_pos_tag(sentence)
        else:
            sentence = naive_split(sentence)
            poss = None
        data = self.dataset.create_model_input(sentence, poss)
        viterbi_sequences, _ = self.model.predict(
            data['sentence_lengths'],
            data['word_ids'],
            data['char_for_ids'],
            data['char_rev_ids'],
            data['word_lengths'],
            data['cap_ids'],
            data['pos_ids'],
        )
        viterbi_sequence = viterbi_sequences[0]
        seq_len = data['sentence_lengths'][0]

        words = data['words'][0][:seq_len]
        mentions, pred_tag_sequence = self.dataset.get_mention_from_words(
            words, viterbi_sequence
        )  # 'mentions' contains start index of each mention
        mention_to_likelihood = dict()
        likelihood = self.get_sequence_likelihood(data, viterbi_sequences)[0]

        for m in mentions:
            mention_to_likelihood[m] = likelihood
        res = dict()
        res['sentence'] = ' '.join(sentence)
        res['mentions'] = mention_to_likelihood
        if poss:
            res['pos'] = poss
        return res

    def tag_top2(self, sentence):
        if self.use_part_of_speech:
            sentence, poss = self.get_pos_tag(sentence)
        else:
            sentence = naive_split(sentence)
            poss = None
        data = self.dataset.create_model_input(sentence, poss)

        viterbi_sequences, scores = self.model.predict_top_k(
            data['sentence_lengths'],
            data['word_ids'],
            data['char_for_ids'],
            data['char_rev_ids'],
            data['word_lengths'],
            data['cap_ids'],
            data['pos_ids'],
        )
        seq_len = data['sentence_lengths'][0]
        words = data['words'][0][:seq_len]
        mention_to_likelihood = dict()
        for k in range(2):
            if k == 1 and scores[0][1] * 1.0 / scores[0][0] < 0.95:
                break
            viterbi_sequence_ = viterbi_sequences[0][k]
            likelihood = self.get_sequence_likelihood(data,
                                                      [viterbi_sequence_])[0]

            pred_entities, pred_tag_sequence = self.dataset.get_mention_from_words(
                words, viterbi_sequence_)
            for e in pred_entities:
                if e not in mention_to_likelihood:
                    mention_to_likelihood[e] = likelihood

        res = dict()
        res['mentions'] = mention_to_likelihood
        res['sentence'] = ' '.join(sentence)
        if poss:
            res['pos'] = poss
        return res

    def get_sequence_likelihood(self, batch_data, batch_sequence):
        batch_sequence = [
            self.dataset.pad_xx(s, self.dataset.tag_padding)
            for s in batch_sequence
        ]
        scores = self.model.get_likelihood(
            batch_sequence,
            batch_data['sentence_lengths'],
            batch_data['word_ids'],
            batch_data['char_for_ids'],
            batch_data['char_rev_ids'],
            batch_data['word_lengths'],
            batch_data['cap_ids'],
            batch_data['pos_ids'],
        )
        return scores.tolist()

    def get_mention_likelihood(self, question, mention):
        if self.use_part_of_speech:
            sentence, poss = self.get_pos_tag(question)
        else:
            sentence = naive_split(question)
            poss = None
        mention = mention.split()
        data = self.dataset.create_model_input(sentence, poss)
        start = find_word(sentence, mention)
        end = start + len(mention)
        tag_ids = self.dataset.create_tag_sequence(start, end, len(sentence),
                                                   self.tag_scheme)
        scores = self.model.get_likelihood(
            tag_ids,
            data['sentence_lengths'],
            data['word_ids'],
            data['char_for_ids'],
            data['char_rev_ids'],
            data['word_lengths'],
            data['cap_ids'],
            data['pos_ids'],
        )
        return question, scores.tolist()[0]
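
A usage sketch for EntityMentionTagger; the model directory and the question are placeholders:

# Hypothetical trained-model directory.
tagger = EntityMentionTagger('runs/entity_tagger')
res = tagger.tag('where was barack obama born')
print(res['sentence'])   # tokenized sentence
print(res['mentions'])   # mention -> sequence likelihood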