示例#1
0
 def entity_constraint_extraction(self, el_result, vb=0):
     """Collect anchor predicates for each linked entity.

     :param el_result: entity-linking items; each carries .entity.id / .name.
     :param vb: verbose flag; vb == 1 prints the anchors of every object.
     :return: (ec_dict, total_size) where ec_dict maps each linking item to
              its filtered anchor predicate list.
     """
     ec_dict = {}
     for el_item in el_result:
         candidate_preds = self.driver.query_pred_given_object(
             el_item.entity.id)
         # drop schema-level predicates (common.* / type.*)
         kept = [p for p in candidate_preds
                 if not p.startswith('common') and not p.startswith('type')]
         ec_dict[el_item] = kept
     total_size = sum(len(preds) for preds in ec_dict.values())
     if vb == 1:
         for el_item, p_list in ec_dict.items():
             LogInfo.begin_track('%d anchors for obj = %s (%s): ',
                                 len(p_list),
                                 el_item.entity.id.encode('utf-8'),
                                 el_item.name.encode('utf-8'))
             for p in p_list:
                 LogInfo.logs(p)
             LogInfo.end_track()
     LogInfo.logs(
         'In total %d candidate objects and ' +
         '%d <obj_mid, anchor_pred> extracted.', len(ec_dict), total_size)
     return ec_dict, total_size
示例#2
0
    def forward(self,
                path_wd_hidden,
                path_kb_hidden,
                path_len,
                focus_wd_hidden,
                focus_kb_hidden,
                reuse=None):
        """Encode [focus; path] with a BiRNN and mean-pool over valid steps.

        The focus hidden state is prepended as an extra timestamp before the
        path sequence, so the effective sequence length is path_len + 1.

        :param path_wd_hidden: word-side path hidden states
            (assumes (batch, path_max_len, dim_item_hidden) -- TODO confirm).
        :param path_kb_hidden: KB-side path hidden states
            (assumes (batch, path_max_len, dim_kb_hidden) -- TODO confirm).
        :param path_len: actual path lengths (batch, ).
        :param focus_wd_hidden: word-side focus hidden state (batch, dim).
        :param focus_kb_hidden: KB-side focus hidden state (batch, dim).
        :param reuse: TF variable-scope reuse flag, forwarded to the encoder.
        :return: sk_hidden, (batch, dim_sk_hidden) averaged BiRNN outputs.
        """
        LogInfo.begin_track('SkBiRNNModule forward: ')

        with tf.variable_scope('SkBiRNNModule', reuse=reuse):
            # choose the input representation based on the configured source:
            # 'kb' / 'word' use one side only, anything else concatenates both
            if self.data_source == 'kb':
                use_path_hidden = path_kb_hidden
                use_focus_hidden = focus_kb_hidden
            elif self.data_source == 'word':
                use_path_hidden = path_wd_hidden
                use_focus_hidden = focus_wd_hidden
            else:
                use_path_hidden = tf.concat([path_kb_hidden, path_wd_hidden],
                                            axis=-1,
                                            name='use_path_hidden')
                # (batch, path_max_len, dim_item_hidden + dim_kb_hidden)
                use_focus_hidden = tf.concat(
                    [focus_kb_hidden, focus_wd_hidden],
                    axis=-1,
                    name='use_focus_hidden')
                # (batch, dim_item_hidden + dim_kb_hidden)

            # prepend the focus as timestamp 0 of the sequence
            use_path_emb_input = tf.concat(
                [tf.expand_dims(use_focus_hidden, axis=1), use_path_hidden],
                axis=1,
                name='use_path_emb_input'
            )  # (batch, path_max_len + 1, dim_use)
            show_tensor(use_path_emb_input)
            use_path_len = path_len + 1  # +1 accounts for the focus timestamp
            stamps = self.path_max_len + 1
            # the encoder expects a list of per-timestamp tensors
            birnn_inputs = tf.unstack(use_path_emb_input,
                                      num=stamps,
                                      axis=1,
                                      name='birnn_inputs')
            encoder_output = self.rnn_encoder.encode(
                inputs=birnn_inputs, sequence_length=use_path_len, reuse=reuse)
            rnn_outputs = tf.stack(
                encoder_output.outputs, axis=1,
                name='rnn_outputs')  # (batch, path_max_len + 1, dim_sk_hidden)

            # Since we are in the BiRNN mode, we are simply taking average.

            sum_sk_hidden = tf.reduce_sum(
                rnn_outputs, axis=1,
                name='sum_sk_hidden')  # (batch, dim_sk_hidden)
            use_path_len_mat = tf.cast(
                tf.expand_dims(use_path_len, axis=1),
                dtype=tf.float32,
                name='use_path_len_mat')  # (batch, 1) as float32
            # divide by the true length so padded timestamps do not dilute
            # the mean (use_path_len >= 1, so no division by zero here)
            sk_hidden = tf.div(sum_sk_hidden,
                               use_path_len_mat,
                               name='sk_hidden')  # (batch, dim_sk_hidden)

        LogInfo.end_track()
        return sk_hidden
示例#3
0
 def print_att_matrix(word_surf_list, path_surf_list, att_mat):
     """
     Given surfaces of words in Q, and the indicator of each skeleton,
     print the attention matrix in a nice format.
     Rows/columns beyond the real words/paths are shown as <EMPTY>, and a
     '|' / horizontal rule separates real entries from padding.
     NOTE(review): calls LogInfo.end_track() without a visible begin_track --
     presumably the caller opened the track; confirm.
     """
     n_words = len(word_surf_list)
     n_paths = len(path_surf_list)
     # header row: one cell per attention column
     header_cells = [' ' * 10]
     for col_idx in range(att_mat.shape[1]):
         if col_idx < n_paths:
             header_cells.append('%9s' % path_surf_list[col_idx])
         else:
             header_cells.append('  <EMPTY>')
         if col_idx == n_paths - 1:
             header_cells.append('|')
     LogInfo.logs(' '.join(header_cells))
     for row_idx, att_row in enumerate(att_mat):
         # one output row per attention-matrix row
         wd = word_surf_list[row_idx] if row_idx < n_words else '<EMPTY>'
         cells = ['%10s' % wd]
         for col_idx, att_val in enumerate(att_row):
             cells.append('%9.6f' % att_val)
             if col_idx == n_paths - 1:
                 cells.append('|')
         show_str = ' '.join(cells)
         LogInfo.logs(show_str)
         if row_idx == n_words - 1:
             # rule between the useful rows and the padding rows
             LogInfo.logs('-' * len(show_str))
     LogInfo.end_track()
示例#4
0
    def query_sparql(self, sparql_query):
        """Run a SPARQL query through the local cache.

        :param sparql_query: the raw SPARQL string (shrunk into a cache key).
        :return: the query result; never None -- an empty list on failure.
        """
        key = self.shrink_query(sparql_query)
        hit = key in self.sparql_dict
        # show the request when we will actually query, or verbosity is high
        show_request = not hit or self.vb >= 2
        if show_request:
            LogInfo.begin_track('[%s] SPARQL Request:', show_time())
            LogInfo.logs(key)
        if hit and self.vb >= 1:
            if show_request:
                LogInfo.logs('SPARQL hit!')
            else:
                LogInfo.logs('[%s] SPARQL hit!', show_time())
        if hit:
            query_ret = self.sparql_dict[key]
        else:
            query_ret = self.kernel_query(key=key)
            # cache only real answers; schemas returning None are ignored
            if query_ret is not None:
                self.sparql_dict[key] = query_ret
                self.sparql_buffer.append((key, query_ret))
        if show_request and self.vb >= 3:
            LogInfo.logs(query_ret)
        if show_request:
            LogInfo.end_track()

        # won't send None back, but use empty list as instead
        return query_ret or []
示例#5
0
def show_el_detail(sc):
    """Print per-entity linking features and raw scores of one schema.

    Pairs each Entity/Main raw path with its prominent type (domain of the
    first predicate for Main, range of the last predicate for Entity).
    """
    el_feats_concat = sc.run_info['el_feats_concat'].tolist()
    el_raw_score = sc.run_info['el_raw_score'].tolist()
    el_len = sc.input_np_dict['el_len']
    gl_tup_list = []
    for category, gl_data, pred_seq in sc.raw_paths:
        if category == 'Main':
            gl_tup_list.append((gl_data, get_domain(pred_seq[0])))
        elif category == 'Entity':
            gl_tup_list.append((gl_data, get_range(pred_seq[-1])))
        # other categories do not correspond to entity linkings
    assert len(gl_tup_list) == el_len
    for el_idx, (gl_data, tp) in enumerate(gl_tup_list):
        LogInfo.begin_track('Entity %d / %d:', el_idx + 1, el_len)
        LogInfo.logs(gl_data.display())
        LogInfo.logs(
            'Prominent type: [%s]  <---  (Ignore this line if type is not used.)',
            tp)
        LogInfo.logs('raw_score = %.6f', el_raw_score[el_idx])
        show_str = '  '.join(['%6.3f' % x for x in el_feats_concat[el_idx]])
        LogInfo.logs('el_feats_concat = %s', show_str)
        LogInfo.end_track()
示例#6
0
def load_annotations_bio(word_dict, q_max_len):
    """Read annotation, convert to B,I,O format, and store into numpy arrays.

    Only annotations with jaccard == 1.0 (exact mention matches) are kept;
    sentences longer than q_max_len are dropped, shorter ones are padded
    with word index 0 and tag 2 ('O').

    :param word_dict: word -> index mapping; every annotation token must exist.
    :param q_max_len: maximum question length in tokens.
    :return: [v, v_len, tag] numpy int32 arrays of shapes
             (ds, q_max_len), (ds, ), (ds, q_max_len).
    """
    LogInfo.begin_track('Load SimpQ-mention annotation from [%s]:', anno_fp)
    raw_tup_list = []  # [(v, v_len, tag)]
    with codecs.open(anno_fp, 'r', 'utf-8') as br:
        for line_idx, line in enumerate(br.readlines()):
            spt = line.strip().split('\t')
            q_idx, st, ed = [int(x) for x in spt[:3]]
            jac = float(spt[3])
            if jac != 1.0:
                continue  # only pick the most accurate sentences
            tok_list = spt[-1].lower().split(' ')
            v_len = len(tok_list)
            v = [word_dict[tok]
                 for tok in tok_list]  # TODO: make sure all word exists
            # 0: B (begin of mention), 1: I (inside), 2: O (outside)
            tag = [2] * st + [0] + [1] * (ed - st - 1) + [2] * (v_len - ed)
            assert len(tag) == len(v)
            raw_tup_list.append((v, v_len, tag))
    q_size = len(raw_tup_list)
    v_len_list = [tup[1] for tup in raw_tup_list]
    LogInfo.logs('%d high-quality annotation loaded.', q_size)
    LogInfo.logs('maximum length = %d (%.6f on avg)', np.max(v_len_list),
                 np.mean(v_len_list))
    for pos in (25, 50, 75, 90, 95, 99, 99.9):
        LogInfo.logs('Percentile = %.1f%%: %.6f', pos,
                     np.percentile(v_len_list, pos))

    # List comprehension instead of filter(): under Python 3 the lazy filter
    # object has no len() and could only be iterated once, while it is
    # consumed three times below.
    filt_tup_list = [tup for tup in raw_tup_list if tup[1] <= q_max_len]
    LogInfo.logs('%d / %d sentence filtered by [q_max_len=%d].',
                 len(filt_tup_list), q_size, q_max_len)

    # pad word indices with 0 and tags with 'O' up to q_max_len (in place)
    for v, _, tag in filt_tup_list:
        v += [0] * (q_max_len - len(v))
        tag += [2] * (q_max_len - len(tag))
    v_list, v_len_list, tag_list = [[tup[i] for tup in filt_tup_list]
                                    for i in range(3)]
    np_data_list = [
        np.array(v_list, dtype='int32'),  # (ds, q_max_len)
        np.array(v_len_list, dtype='int32'),  # (ds, )
        np.array(tag_list, dtype='int32')  # (ds, q_max_len)
    ]
    for idx, np_data in enumerate(np_data_list):
        LogInfo.logs('np-%d: %s', idx, np_data.shape)
    LogInfo.end_track()
    return np_data_list
示例#7
0
def show_el_detail_without_type(sc, qa, schema_dataset):
    """Print entity-linking features of one schema (type info not involved)."""
    assert qa is not None
    assert schema_dataset is not None

    el_final_feats = sc.run_info['el_final_feats'].tolist()
    LogInfo.logs('el_final_feats = [%s]',
                 ' '.join(['%6.3f' % x for x in el_final_feats]))

    el_mask = sc.input_np_dict['el_mask']
    path_size = sc.input_np_dict['path_size']
    el_indv_feats = sc.input_np_dict['el_indv_feats']
    # one gl_data per raw path; categories are irrelevant here
    gl_list = [gl_data for _, gl_data, _ in sc.raw_paths]
    assert path_size == len(gl_list)

    for el_idx, gl_data in enumerate(gl_list):
        LogInfo.begin_track('Entity %d / %d:', el_idx + 1, path_size)

        LogInfo.logs(gl_data.display())
        if el_mask[el_idx] == 0.:
            # masked-out entries carry no meaningful features
            LogInfo.logs('[Mask = 0, IGNORED.]')
        else:
            feat_str = '  '.join(['%6.3f' % x
                                  for x in el_indv_feats[el_idx]])
            LogInfo.logs('local_feats = [%s]', feat_str)
        LogInfo.end_track()
示例#8
0
    def post_process(self, eval_dl, detail_fp):
        """Compute chunk-level P/R/F1 and average log-likelihood.

        Iterates over the (v_len, gold_tag, pred_tag, log_lik) tuples gathered
        in self.eval_detail_list, compares predicted vs. gold chunks, and
        prints the first few cases for manual inspection.

        :param eval_dl: evaluation data loader (currently unused here).
        :param detail_fp: detail output path (currently unused here).
        :return: (f1, avg_log_lik); both 0.0 when there is nothing to score.
        """
        total_pred = 0  # total number of predicted chunks
        total_gold = 0  # total number of gold chunks
        total_correct = 0
        avg_log_lik = 0.

        data_size = 0
        for v_len, gold_tag, pred_tag, log_lik in zip(*self.eval_detail_list):
            data_size += 1
            # truncate paddings before chunk extraction
            gold_tag = gold_tag[:v_len]
            pred_tag = pred_tag[:v_len]
            gold_chunk_set = self.produce_chunk(gold_tag)
            pred_chunk_set = self.produce_chunk(pred_tag)
            total_pred += len(pred_chunk_set)
            total_gold += len(gold_chunk_set)
            total_correct += len(pred_chunk_set & gold_chunk_set)
            avg_log_lik += log_lik
            if data_size <= 5:  # show a handful of cases for sanity check
                LogInfo.begin_track('Check case-%d:', data_size)
                LogInfo.logs('seq_len = %d', v_len)
                LogInfo.logs('Gold: %s --> %s', gold_tag.tolist(),
                             gold_chunk_set)
                LogInfo.logs('Pred: %s --> %s', pred_tag.tolist(),
                             pred_chunk_set)
                LogInfo.logs('log_lik = %.6f', log_lik)
                LogInfo.end_track()
        # Guard all denominators: the original raised ZeroDivisionError when
        # there were no predictions, no gold chunks, or an empty eval list.
        p = 1. * total_correct / total_pred if total_pred > 0 else 0.
        r = 1. * total_correct / total_gold if total_gold > 0 else 0.
        f1 = 2. * p * r / (p + r) if (p + r) > 0. else 0.
        avg_log_lik = avg_log_lik / data_size if data_size > 0 else 0.
        return f1, avg_log_lik
示例#9
0
    def load_cands(self):
        """Load question list, word dict, candidates and np arrays from dump.

        No-op if everything is already in memory; falls back to
        prepare_all_data() when the dump file does not exist yet.
        NOTE: the cPickle.load / np.load calls below read sequentially from
        the same file handle, so their order must mirror the dump order.
        """
        if len(self.np_data_list) > 0 and \
                self.q_cand_dict is not None and \
                self.q_words_dict is not None:
            return  # already loaded
        if not os.path.isfile(self.dump_fp):
            self.prepare_all_data()  # no dump yet: build everything from scratch
            return
        LogInfo.begin_track('Loading candidates & np_data from [%s] ...', self.dump_fp)
        with open(self.dump_fp, 'rb') as br:
            self.q_list = cPickle.load(br)
            LogInfo.logs('q_list loaded for %d questions.', len(self.q_list))
            self.q_words_dict = cPickle.load(br)
            LogInfo.logs('q_words_dict loaded for %d questions.', len(self.q_words_dict))
            self.q_cand_dict = cPickle.load(br)
            LogInfo.logs('q_cand_dict loaded.')

            # distribution of candidate counts per question, for sanity check
            cand_size_dist = np.array([len(v) for v in self.q_cand_dict.values()])
            LogInfo.begin_track('Show candidate size distribution:')
            for pos in (25, 50, 75, 90, 95, 99, 99.9, 100):
                LogInfo.logs('Percentile = %.1f%%: %.6f', pos, np.percentile(cand_size_dist, pos))
            LogInfo.end_track()

            # the numpy arrays were appended after the pickled objects
            for data_idx in range(self.array_num):
                np_data = np.load(br)
                self.np_data_list.append(np_data)
                LogInfo.logs('np-data-%d loaded: %s', data_idx, np_data.shape)
        LogInfo.end_track()
示例#10
0
def type_filtering(el_result,
                   tl_result,
                   sparql_driver,
                   is_type_extend=True,
                   vb=0):
    """Keep only type linkings consistent with predicates of linked entities.

    :param el_result: entity linkings; their mids seed predicate collection.
    :param tl_result: candidate type linkings to be filtered.
    :param sparql_driver: SPARQL engine for collect_relevant_predicate().
    :param is_type_extend: whether to extend the consistent type set.
    :param vb: verbose level.
    :return: list of type linkings whose type id is topically consistent.
    """
    if vb >= 1:
        LogInfo.begin_track('Type Filtering:')
    relevant_preds = set()
    for el in el_result:
        relevant_preds |= collect_relevant_predicate(el.entity.id,
                                                     sparql_driver)
    if vb >= 1:
        LogInfo.logs('%d relevant predicates collected.', len(relevant_preds))

    topical_consistent_types = prepare_topical_consistent_types(
        relevant_pred_set=relevant_preds,
        is_type_extended=is_type_extend,
        vb=vb)
    # List comprehension instead of filter(): under Python 3 the lazy filter
    # object has no len(), breaking the log line below, and callers expect
    # a concrete list.
    filt_tl_result = [tl for tl in tl_result
                      if tl.entity.id in topical_consistent_types]
    LogInfo.logs('Type Filter: %d / %d types are kept.', len(filt_tl_result),
                 len(tl_result))
    if vb >= 1:
        LogInfo.end_track()

    return filt_tl_result
示例#11
0
def construct_gather_linkings(el_result, tl_result, tml_result,
                              tml_comp_result):
    """Collect all E / T / Tm linkings into a single LinkData list."""
    links = []
    for el in el_result:
        # entity linkings must carry their linking features
        assert hasattr(el, 'link_feat')
        disp = 'E: [%d, %d) %s (%s) %.6f' % (
            el.tokens[0].index, el.tokens[-1].index + 1,
            el.entity.id.encode('utf-8'), el.name.encode('utf-8'),
            el.surface_score)
        links.append(LinkData(el, 'Entity', '==', disp, el.link_feat))
    for tl in tl_result:
        disp = 'T: [%d, %d) %s (%s) %.6f' % (
            tl.tokens[0].index, tl.tokens[-1].index + 1,
            tl.entity.id.encode('utf-8'), tl.name.encode('utf-8'),
            tl.surface_score)
        links.append(LinkData(tl, 'Type', '==', disp, []))
    for tml, comp in zip(tml_result, tml_comp_result):
        # time linkings additionally carry their comparison operator
        disp = 'Tm: [%d, %d) %s %s %.6f' % (
            tml.tokens[0].index, tml.tokens[-1].index + 1, comp,
            tml.entity.sparql_name().encode('utf-8'), tml.surface_score)
        links.append(LinkData(tml, 'Time', comp, disp, []))
    LogInfo.begin_track('%d E + %d T + %d Tm = %d links.', len(el_result),
                        len(tl_result), len(tml_result), len(links))
    for link_data in links:
        LogInfo.logs(link_data.display)
    LogInfo.end_track()
    return links
示例#12
0
def pick_one_search(spec_linkings, conflict_matrix, tag_set, av_combs, spec,
                    vb=0):
    """
    Work for T/Tm/Ord, since only one of them can be selected, no need for DFS.

    :param spec_linkings: linkings of the target category.
    :param conflict_matrix: conflict_matrix[i] lists linking positions that
                            conflict with linking i.
    :param tag_set: combination tags that are allowed to survive.
    :param av_combs: current states, each being a
                     (gl_data_indices, tag_elements, visit_arr) tuple.
    :param spec: one of 'T' / 'Tm' / 'Ord'.
    :param vb: verbose level. (Fix: `vb` was referenced but never defined in
               the original, raising NameError whenever a tag matched; it is
               now a backward-compatible parameter defaulting to silent.)
    :return: list of extended states.
    """
    assert spec in ('T', 'Tm', 'Ord')
    LogInfo.begin_track('Searching at %s level ...', spec)
    spec_available_combs = []
    for gl_data_indices, tag_elements, visit_arr in av_combs:
        for gl_data in spec_linkings:
            gl_pos = gl_data.gl_pos
            if visit_arr[gl_pos] != 0:  # cannot be visited due to conflict
                continue
            new_visit_arr = list(visit_arr)  # new state after applying types
            for conf_idx in conflict_matrix[gl_pos]:
                new_visit_arr[conf_idx] += 1
            if spec in ('Tm', 'Ord'):
                tag_elem = spec
            else:
                tag_elem = 'T:%s' % gl_data.value
            new_gl_data_indices = list(gl_data_indices) + [gl_pos]
            new_tag_elements = list(tag_elements) + [tag_elem]
            tag = '|'.join(new_tag_elements)
            if tag in tag_set:
                if vb >= 1:
                    LogInfo.logs(tag)
                spec_available_combs.append(
                    (new_gl_data_indices, new_tag_elements, new_visit_arr))
    LogInfo.end_track()
    return spec_available_combs
示例#13
0
def main():
    """Run evaluation of several experiments against Yih et al. predictions."""
    qa_list = load_webq()
    # load the reference predictions, one "<question>\t<answer>" pair per line
    yih_ret_fp = 'codalab/WebQ/acl2015-msr-stagg/test_predict.txt'
    yih_ret_dict = {}
    with codecs.open(yih_ret_fp, 'r', 'utf-8') as br:
        for line in br.readlines():
            k, v = line.strip().split('\t')
            yih_ret_dict[k] = v
    LogInfo.logs('Yih result collected.')

    # (experiment suffix, data suffix, best epoch)
    exp_tup_list = [('180514_strict/all__full__180508_K03_Fhalf__depSimulate/'
                     'NFix-20__wUpd_RH_qwOnly_compact__b32__fbpFalse',
                     '180508_K03_Fhalf', 10),
                    ('180516_strict/all__full__180508_K03_Fhalf__Lemmatize/'
                     'NFix-20__wUpd_RH_qwOnly_compact__b32__fbpFalse',
                     '180508_K03_Fhalf', 15),
                    ('180516_strict/all__full__180508_K03_Fhalf__Lemmatize/'
                     'NFix-20__wUpd_BH_qwOnly_compact__b32__fbpFalse',
                     '180508_K03_Fhalf', 12)]
    for exp_suf, data_suf, best_epoch in exp_tup_list:
        LogInfo.begin_track('Dealing with [%s], epoch = %03d:', exp_suf,
                            best_epoch)
        work(exp_dir='runnings/WebQ/' + exp_suf,
             data_dir='runnings/candgen_WebQ/' + data_suf,
             best_epoch=best_epoch,
             qa_list=qa_list,
             yih_ret_dict=yih_ret_dict)
        LogInfo.end_track()
示例#14
0
    def get_gradient_tf_list(self, score_tf):
        """Build per-item score gradients w.r.t. every global variable.

        :param score_tf: (batch_size, list_len) tensor of item scores.
        :return: a list parallel to tf.global_variables(); each element is a
                 tensor of shape (batch_size, list_len, *var_shape).
        Note: this creates batch_size * list_len gradient ops per variable,
        which can be very slow to build for large models.

        Fixes: log-message typo ('genearating' -> 'generating');
        tf.global_variables() is now called once instead of on every loop
        iteration.
        """
        LogInfo.begin_track('LambdaRank generating gradients ... ')
        grad_tf_list = []  # the return value

        all_vars = tf.global_variables()  # hoisted: the set does not change here
        for scan, var in enumerate(all_vars, start=1):
            LogInfo.begin_track('Variable %d / %d %s: ', scan, len(all_vars),
                                var.get_shape().as_list())
            per_row_grad_tf_list = []
            for row_idx in range(self.batch_size):
                LogInfo.begin_track('row_idx = %d / %d: ', row_idx + 1,
                                    self.batch_size)
                local_grad_tf_list = []
                for item_idx in range(self.list_len):
                    if (item_idx + 1) % 50 == 0:  # progress heartbeat
                        LogInfo.logs('item_idx = %d / %d', item_idx + 1,
                                     self.list_len)
                    local_grad_tf = tf.gradients(score_tf[row_idx, item_idx],
                                                 var)[0]  # ("var_shape", )
                    local_grad_tf_list.append(local_grad_tf)
                per_row_grad_tf = tf.stack(local_grad_tf_list, axis=0)
                per_row_grad_tf_list.append(per_row_grad_tf)
                # per_row_grad_tf: (list_len, "var_shape")
                LogInfo.end_track()
            grad_tf = tf.stack(per_row_grad_tf_list, axis=0)
            grad_tf_list.append(grad_tf)
            LogInfo.logs('grad_tf: %s', grad_tf.get_shape().as_list())
            # grad_tf: (batch_size, list_len, "var_shape")
            LogInfo.end_track()
        return grad_tf_list
示例#15
0
def analyze_output(data_dir, sort_item):
    """ Check the rank distribution in a global perspective """
    # four tiers: F1 = 1.0, F1 >= 0.5, F1 >= 0.1, F1 > 0
    rank_matrix = [[], [], [], []]
    fp = '%s/lexicon_validate/srt_output.%s.txt' % (data_dir, sort_item)
    with codecs.open(fp, 'r', 'utf-8') as br:
        for line in br.readlines():
            spt = line.strip().split('\t')
            f1 = float(spt[-2])
            rank = int(spt[-1])
            if rank == -1:
                # huge sentinel rank: effectively "never retrieved"
                rank = 2111222333
            if f1 == 1.0:  # tier 0 keeps exact matches only
                rank_matrix[0].append(rank)
            for tier, ths in ((1, 0.5), (2, 0.1), (3, 1e-6)):
                if f1 >= ths:
                    rank_matrix[tier].append(rank)
    for ths, rank_list in zip((1.0, 0.5, 0.1, 1e-6), rank_matrix):
        LogInfo.begin_track('Show stat for F1 >= %.6f:', ths)
        ranks = np.array(rank_list)
        LogInfo.logs('Total cases = %d.', len(ranks))
        LogInfo.logs('MRR = %.6f', np.mean(1. / ranks))
        for pos in (50, 60, 70, 80, 90, 95, 99, 99.9, 100):
            LogInfo.logs('Percentile = %.1f%%: %.6f', pos,
                         np.percentile(ranks, pos))
        LogInfo.end_track()
示例#16
0
    def load_all_data(self):
        """Load candidates (from pickle or raw txt) and print meta statistics."""
        if len(self.smart_q_cand_dict) > 0:  # already loaded
            return
        # prefer the pickled dump; fall back to parsing the raw txt files
        if os.path.isfile(self.dump_fp):
            self.load_smart_schemas_from_pickle()
        else:
            self.load_smart_schemas_from_txt()

        LogInfo.begin_track('Meta statistics:')

        LogInfo.logs('Total questions = %d', len(self.q_idx_list))
        LogInfo.logs('T / v / t questions = %s',
                     [len(lst) for lst in self.spt_q_idx_lists])
        LogInfo.logs(
            'Active Word / Mid / Path = %d / %d / %d (with PAD, START, UNK)',
            len(self.active_dicts['word']), len(self.active_dicts['mid']),
            len(self.active_dicts['path']))
        LogInfo.logs('path_max_size = %d, qw_max_len = %d, pw_max_len = %d.',
                     self.path_max_size, self.qw_max_len, self.pw_max_len)

        # candidate count distribution over all questions
        cand_sizes = np.array(
            [len(cands) for cands in self.smart_q_cand_dict.values()])
        LogInfo.logs('Total schemas = %d, avg = %.6f.', np.sum(cand_sizes),
                     np.mean(cand_sizes))
        for pos in (25, 50, 75, 90, 95, 99, 99.9, 100):
            LogInfo.logs('cand_size @ %.1f%%: %.6f', pos,
                         np.percentile(cand_sizes, pos))

        # question length distribution (in tokens)
        q_lens = np.array([len(qa['tokens']) for qa in self.qa_list])
        LogInfo.logs('Avg question length = %.6f.', np.mean(q_lens))
        for pos in (25, 50, 75, 90, 95, 99, 99.9, 100):
            LogInfo.logs('question_len @ %.1f%%: %.6f', pos,
                         np.percentile(q_lens, pos))

        LogInfo.end_track()
示例#17
0
 def load_smart_cands(self):
     """Load candidate schemas and lookup tables, then show meta statistics.

     No-op if already in memory; parses the raw txt files when no pickle
     dump exists. NOTE: the cPickle.load calls below read sequentially from
     one file handle, so their order must mirror the dump order exactly.
     """
     if self.smart_q_cand_dict is not None:  # already loaded
         return
     if not os.path.isfile(self.dump_fp):  # no dump, read from txt
         self.load_smart_schemas_from_txt()
     else:
         LogInfo.begin_track('Loading smart_candidates from [%s] ...',
                             self.dump_fp)
         with open(self.dump_fp, 'rb') as br:
             LogInfo.begin_track('Loading smart_q_cand_dict ... ')
             self.smart_q_cand_dict = cPickle.load(br)
             LogInfo.logs('Candidates for %d questions loaded.',
                          len(self.smart_q_cand_dict))
             # candidate count distribution, for sanity check
             cand_size_dist = np.array(
                 [len(v) for v in self.smart_q_cand_dict.values()])
             LogInfo.logs('Total schemas = %d, avg = %.6f.',
                          np.sum(cand_size_dist), np.mean(cand_size_dist))
             for pos in (25, 50, 75, 90, 95, 99, 99.9, 100):
                 LogInfo.logs('Percentile = %.1f%%: %.6f', pos,
                              np.percentile(cand_size_dist, pos))
             LogInfo.end_track()
             self.path_idx_dict = cPickle.load(br)
             self.entity_idx_dict = cPickle.load(br)
             self.type_idx_dict = cPickle.load(br)
             LogInfo.logs('Active E/T/Path dict loaded.')
             self.pw_voc_inputs = cPickle.load(br)  # (path_voc, pw_max_len)
             self.pw_voc_length = cPickle.load(br)  # (path_voc,)
             self.pw_voc_domain = cPickle.load(br)  # (path_voc,)
             self.entity_type_matrix = cPickle.load(
                 br)  # (entity_voc, type_voc)
             # recover pw_max_len from the loaded lookup table
             self.pw_max_len = self.pw_voc_inputs.shape[1]
             LogInfo.logs('path word & entity_type lookup tables loaded.')
         self.q_idx_list = sorted(self.smart_q_cand_dict.keys())
         LogInfo.end_track()  # end of loading
     self.meta_stat()  # show meta statistics
示例#18
0
 def batch_schema_f1_query(self, q_id, states, level):
     """
     perform F1 query for each schema in the state.
     :param q_id: WebQ-xxx / CompQ-xxx
     :param states: [(schema, visit_arr)]
     :param level: coarse / typed / timed / ordinal
     :return: filtered states where each schema returns at least one answer.
     """
     Tt.start('%s_F1' % level)
     LogInfo.begin_track('Calculating F1 for %d %s schemas:', len(states), level)
     for idx, (sc, _) in enumerate(states):
         if idx % 100 == 0:  # progress heartbeat
             LogInfo.logs('Current: %d / %d', idx, len(states))
         sparql_str = sc.build_sparql(simple_time_match=self.simple_time_match)
         tm_comp, tm_value, ord_comp, ord_rank, agg = sc.build_aux_for_sparql()
         # do not apply the 'forever' policy when there is no time constraint
         allow_forever = self.allow_forever if tm_comp != 'None' else ''
         q_sc_key = '|'.join([q_id, sparql_str,
                              tm_comp, tm_value, allow_forever,
                              ord_comp, ord_rank, agg])
         if self.vb >= 2:
             LogInfo.begin_track('Checking schema %d / %d:', idx, len(states))
             LogInfo.logs(sc.disp_raw_path())
             LogInfo.logs('var_types: %s', sc.var_types)
             LogInfo.logs(sparql_str)
         Tt.start('query_q_sc_stat')
         sc.ans_size, sc.p, sc.r, sc.f1 = self.query_srv.query_q_sc_stat(q_sc_key)
         Tt.record('query_q_sc_stat')
         if self.vb >= 2:
             LogInfo.logs('Answers = %d, P = %.6f, R = %.6f, F1 = %.6f', sc.ans_size, sc.p, sc.r, sc.f1)
             LogInfo.end_track()
     # List comprehension instead of filter(): under Python 3 the lazy filter
     # object has no len(), breaking the log line below, and callers expect
     # a concrete list.
     filt_states = [tup for tup in states if tup[0].ans_size > 0]
     LogInfo.end_track('%d / %d %s schemas kept with ans_size > 0.', len(filt_states), len(states), level)
     Tt.record('%s_F1' % level)
     return filt_states
示例#19
0
    def forward(self, item_wd_embedding, item_len, reuse=None):
        """Encode an item's word embeddings with a BiRNN and mean-pool them.

        :param item_wd_embedding: word embeddings of the item
            (assumes (data_size, item_max_len, n_emb) -- TODO confirm).
        :param item_len: actual item lengths, (data_size, ).
        :param reuse: TF variable-scope reuse flag, forwarded to the encoder.
        :return: (data_size, n_hidden_emb) averaged hidden states.
        """
        LogInfo.begin_track('ItemBiRNNModule forward: ')

        with tf.variable_scope('ItemBiRNNModule', reuse=reuse):
            stamps = self.item_max_len
            show_tensor(item_wd_embedding)
            # the encoder expects a list of `stamps` tensors of (batch, n_emb)
            birnn_inputs = tf.unstack(item_wd_embedding, num=stamps, axis=1,
                                      name='birnn_inputs')
            encoder_output = self.rnn_encoder.encode(inputs=birnn_inputs,
                                                     sequence_length=item_len,
                                                     reuse=reuse)
            birnn_outputs = tf.stack(encoder_output.outputs, axis=1,
                                     name='birnn_outputs')
            # birnn_outputs: (data_size, q_len, n_hidden_emb)
            LogInfo.logs('birnn_output = %s',
                         birnn_outputs.get_shape().as_list())

            # mean-pool over timestamps, dividing by the true length
            sum_wd_hidden = tf.reduce_sum(birnn_outputs, axis=1)
            item_len_mat = tf.cast(tf.expand_dims(item_len, axis=1),
                                   dtype=tf.float32)  # (data_size, 1)
            item_wd_hidden = tf.div(
                sum_wd_hidden,
                tf.maximum(item_len_mat, 1),  # avoid dividing by 0
                name='item_wd_hidden')  # (data_size, n_hidden_emb)
            LogInfo.logs('item_wd_hidden = %s',
                         item_wd_hidden.get_shape().as_list())

        LogInfo.end_track()
        return item_wd_hidden
示例#20
0
    def __init__(self,
                 use_sparql_cache=True,
                 data_mode='Ordinal',
                 sc_mode='Skeleton',
                 root_path='/home/kangqi/workspace/PythonProject',
                 cache_dir='runnings/compQA/cache'):
        """Set up QA data, the candidate generator and the loss calculator.

        :param use_sparql_cache: forwarded to CandidateGenerator.
        :param data_mode: 'Ordinal' or 'ComplexQuestions' (dataset choice).
        :param sc_mode: 'Skeleton' or 'Sk+Ordinal' (schema mode).
        :param root_path: project root (currently unused here).
        :param cache_dir: cache directory (currently unused here).
        """
        LogInfo.begin_track('Initializing InputGenerator ... ')

        assert data_mode in ('Ordinal', 'ComplexQuestions')
        assert sc_mode in ('Skeleton', 'Sk+Ordinal')

        self.data_mode = data_mode
        self.sc_mode = sc_mode
        if self.data_mode == 'Ordinal':
            self.qa_data = load_complex_questions_ordinal_only()
            self.train_qa_list, self.test_qa_list = self.qa_data
        elif self.data_mode == 'ComplexQuestions':
            self.qa_data = load_complex_questions()
            self.train_qa_list, self.test_qa_list = self.qa_data
        else:
            # unreachable while the assert above holds; kept for safety
            LogInfo.logs('Unknown data mode: %s', self.data_mode)
        self.cand_gen = CandidateGenerator(use_sparql_cache=use_sparql_cache)
        self.loss_calc = LossCalculator(driver=self.cand_gen.driver)

        LogInfo.end_track()
示例#21
0
def main(args):
    """Generate candidate schemas for every SimpleQuestions item.

    Results are sharded into per-1000-question directories; questions whose
    schema file already exists are skipped.
    :param args: parsed CLI args (word_emb, dim_emb, lex_name, fb_subset,
                 verbose, output_dir).
    """
    qa_list = load_simpq()
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    lex_name = args.lex_name
    # lex_name = 'filter-%s-xs-lex-score' % args.fb_subset
    cand_gen = SimpleQCandidateGenerator(fb_subset=args.fb_subset,
                                         lex_name=lex_name,
                                         wd_emb_util=wd_emb_util,
                                         vb=args.verbose)

    for q_idx, qa in enumerate(qa_list):
        LogInfo.begin_track('Entering Q %d / %d [%s]:',
                            q_idx, len(qa_list), qa['utterance'].encode('utf-8'))
        # floor division: integer shard index is intended; plain '/' would
        # yield a float under Python 3
        sub_idx = q_idx // 1000 * 1000
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 999)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        opt_sc_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(opt_sc_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue
        Tt.start('single_q')
        cand_gen.single_question_candgen(q_idx=q_idx, qa=qa,
                                         link_fp=link_fp, opt_sc_fp=opt_sc_fp)
        Tt.record('single_q')
        LogInfo.end_track()     # End of Q
示例#22
0
def start_entity_search(entity_linkings, conflict_matrix, tag_set):
    """Enumerate entity-level (M/E) search states.

    Every entity linking is tried as the main focus; for each type of
    the focus entity a DFS over the remaining, non-conflicting linkings
    is launched.  Returns the combinations collected by the DFS.
    """
    LogInfo.begin_track('Searching at M/E level ...')
    combs = []  # filled in-place by entity_search_dfs
    n_el = len(entity_linkings)
    n_gl = len(conflict_matrix)
    for focus_idx, focus_link in enumerate(entity_linkings):
        focus_pos = focus_link.gl_pos
        # Mark every position conflicting with the chosen focus.
        visits = [0] * n_gl
        for blocked in conflict_matrix[focus_pos]:
            visits[blocked] += 1
        picked_indices = [focus_pos]
        tags = []  # initial state of the search
        type_list = get_entity_type(focus_link.value)
        for type_idx, entity_type in enumerate(type_list):
            marker = [
                'M%d/%d-(t%d/%d)' %
                (focus_idx + 1, n_el, type_idx + 1, len(type_list))
            ]
            tags.append('M:%s' % entity_type)
            entity_search_dfs(entity_linkings=entity_linkings,
                              conflict_matrix=conflict_matrix,
                              tag_set=tag_set,
                              cur_el_idx=-1,
                              gl_data_indices=picked_indices,
                              tag_elements=tags,
                              visit_arr=visits,
                              entity_available_combs=combs,
                              state_marker=marker)
            tags.pop()  # undo the tag pushed for this type
    LogInfo.end_track()
    return combs
示例#23
0
def make_combination(gather_linkings, sparql_driver, vb):
    """
    Given the E/T/Tm linkings, return all the possible combination of query structure.
    The only restrict: can't use multiple linkings with overlapped mention.
    ** Used in either WebQ or CompQ, not SimpQ **
    :param gather_linkings: list of named_tuple (detail, category, comparison, display)
    :param sparql_driver: the sparql query engine
    :param vb: verbose
    :return: the dictionary including all necessary information of a schema.
    """
    sz = len(gather_linkings)
    # Count entity linkings with a generator: len(filter(...)) builds a
    # throwaway list in Python 2 and fails outright on Python 3, where
    # filter() returns a lazy iterator without len().
    el_size = sum(1 for x in gather_linkings if x.category == 'Entity')

    # Step 1: Prepare conflict matrix.
    # conflict_matrix[i] lists every j that cannot co-exist with linking i.
    conflict_matrix = []
    for i in range(sz):
        local_conf_list = []
        for j in range(sz):
            if is_overlap(gather_linkings[i].detail,
                          gather_linkings[j].detail):
                local_conf_list.append(j)
            elif gather_linkings[i].category == 'Type' and gather_linkings[
                    j].category == 'Type':
                # 180205: restriction added for saving time -- allow at
                # most one type constraint per schema; don't make the
                # task even more complex.
                local_conf_list.append(j)
        conflict_matrix.append(local_conf_list)

    # Step 2: start combination searching
    LogInfo.begin_track(
        'Starting searching combination (total links = %d, entities = %d):',
        len(gather_linkings), el_size)
    ground_comb_list = []  # [ (comb, path_len, sparql_query_ret) ]
    for path_len in (1, 2):
        for mf_idx, main_focus in enumerate(gather_linkings):
            if main_focus.category != 'Entity':
                continue
            # visit_arr[i] counts how many active conflicts position i
            # has in the current searching state.
            visit_arr = [0] * sz
            state_marker = [
                'Path-%d||F%d/%d' % (path_len, mf_idx + 1, el_size)
            ]
            cur_comb = [(0, mf_idx)]  # indicating the focus entity
            for conf_idx in conflict_matrix[mf_idx]:
                visit_arr[conf_idx] += 1
            search_start(path_len=path_len,
                         gather_linkings=gather_linkings,
                         sparql_driver=sparql_driver,
                         cur_idx=-1,
                         cur_comb=cur_comb,
                         conflict_matrix=conflict_matrix,
                         visit_arr=visit_arr,
                         ground_comb_list=ground_comb_list,
                         state_marker=state_marker,
                         vb=vb)
    LogInfo.end_track()
    return ground_comb_list
示例#24
0
def main(args):
    """Generate candidate schemas for one 100-question group of WebQ/CompQ.

    In linking-only mode the whole dataset is scanned and no SPARQL
    service is created; otherwise only questions in
    [group_idx*100, group_idx*100+100) are processed, backed by
    per-group SPARQL / schema-stat caches.  Results are written to
    '<output_dir>/data/<bucket>/<q_idx>_{schema,links}'; questions with
    an existing schema file are skipped, so the run is resumable.
    """
    assert args.data_name in ('WebQ', 'CompQ')
    if args.data_name == 'WebQ':
        qa_list = load_webq()
    else:
        qa_list = load_compq()

    if args.linking_only:
        query_srv = None
        q_start = 0
        q_end = len(qa_list)
    else:
        group_idx = args.group_idx
        q_start = group_idx * 100
        q_end = group_idx * 100 + 100
        if args.data_name == 'CompQ':
            sparql_cache_fp = 'runnings/acl18_cache/group_cache/sparql.g%02d.cache' % group_idx
            q_sc_cache_fp = 'runnings/acl18_cache/group_cache/q_sc_stat.g%02d.cache' % group_idx
        else:
            sparql_cache_fp = 'runnings/acl18_cache/group_cache_%s/sparql.g%02d.cache' % (args.data_name, group_idx)
            q_sc_cache_fp = 'runnings/acl18_cache/group_cache_%s/q_sc_stat.g%02d.cache' % (args.data_name, group_idx)
        query_srv = QueryService(
            sparql_cache_fp=sparql_cache_fp,
            qsc_cache_fp=q_sc_cache_fp, vb=1
        )
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb,
                                    dim_emb=args.dim_emb)
    smart_cand_gen = SMARTCandidateGenerator(data_name=args.data_name,
                                             lex_name=args.lex_name,
                                             wd_emb_util=wd_emb_util,
                                             query_srv=query_srv,
                                             allow_forever=args.allow_forever,
                                             vb=args.verbose,
                                             simple_type_match=args.simple_type,
                                             simple_time_match=args.simple_time)

    for q_idx, qa in enumerate(qa_list):
        if q_idx < q_start or q_idx >= q_end:
            continue
        LogInfo.begin_track('Entering Q %d / %d [%s]:', q_idx, len(qa_list), qa['utterance'].encode('utf-8'))
        # Bucket into sub-directories of 100 questions each.
        # Explicit floor division: plain '/' yields a float in Python 3.
        sub_idx = q_idx // 100 * 100
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 99)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        opt_sc_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(opt_sc_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue
        Tt.start('single_q')
        smart_cand_gen.single_question_candgen(q_idx=q_idx, qa=qa,
                                               link_fp=link_fp,
                                               opt_sc_fp=opt_sc_fp,
                                               linking_only=args.linking_only)
        Tt.record('single_q')
        LogInfo.end_track()     # End of Q
示例#25
0
 def save_dicts(self):
     """Pickle the word and mid dictionaries into ``self.dict_fp``.

     Both dicts are dumped sequentially into one file; loaders must
     unpickle them in the same order (word_dict first, mid_dict second).
     """
     LogInfo.begin_track('Saving actual word/mid dict into [%s] ...', self.dict_fp)
     # Pickle files must be opened in binary mode: text mode ('w')
     # applies platform newline translation and corrupts the stream
     # on Windows.
     with open(self.dict_fp, 'wb') as bw:
         cPickle.dump(self.word_dict, bw)
         LogInfo.logs('%d w_dict saved.', len(self.word_dict))
         cPickle.dump(self.mid_dict, bw)
         LogInfo.logs('%d e_dict saved.', len(self.mid_dict))
     LogInfo.end_track()
示例#26
0
 def save_init_emb(self):
     """Write the initial word/mid embedding matrices to an .npz archive."""
     word_emb = self.word_init_emb
     mid_emb = self.mid_init_emb
     LogInfo.begin_track('Saving initial embedding matrix into [%s] ...', self.init_mat_fp)
     # np.savez stores each matrix under its keyword name in the archive.
     np.savez(self.init_mat_fp, word_init_emb=word_emb, mid_init_emb=mid_emb)
     LogInfo.logs('Word embedding: %s saved.', word_emb.shape)
     LogInfo.logs('Mid embedding: %s saved.', mid_emb.shape)
     LogInfo.end_track()
示例#27
0
def main_default(qa_list, linking_wrapper, linking_cache, sparql_driver,
                 simpq_sp_dict, args):
    """Run candidate generation for qa_list[args.q_start:args.q_end].

    For each question three files are produced under
    '<output_dir>/data/<bucket>/': '<q_idx>_schema', '<q_idx>_ans' and
    '<q_idx>_links'.  Files are first written with a '.tmp' suffix and
    then renamed, so partially-written results are never mistaken for
    finished ones; questions with all three files present are skipped.
    """
    q_size = len(qa_list)
    for q_idx, qa in enumerate(qa_list):
        if q_idx < args.q_start or q_idx >= args.q_end:
            continue
        LogInfo.begin_track('Entering Q %d / %d [%s]:', q_idx, q_size,
                            qa['utterance'].encode('utf-8'))
        # Bucket into sub-directories of 100 questions each.
        # Explicit floor division: plain '/' yields a float in Python 3.
        sub_idx = q_idx // 100 * 100
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 99)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        save_schema_fp = '%s/%d_schema' % (sub_dir, q_idx)
        save_ans_fp = '%s/%d_ans' % (sub_dir, q_idx)
        save_link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(save_schema_fp) \
                and os.path.isfile(save_ans_fp) \
                and os.path.isfile(save_link_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue

        target_value = qa['targetValue']
        gather_linkings, schema_json_list, schema_ans_list = \
            process_single_question(data_name=args.data_name,
                                    q=qa['utterance'],
                                    q_idx=q_idx,
                                    target_value=target_value,
                                    linking_wrapper=linking_wrapper,
                                    linking_cache=linking_cache,
                                    ground_truth=args.ground_truth,
                                    sparql_driver=sparql_driver,
                                    simpq_sp_dict=simpq_sp_dict,
                                    vb=args.verbose)

        # Pickle requires binary mode ('w' corrupts data on Windows).
        with open(save_link_fp + '.tmp', 'wb') as bw_link:
            cPickle.dump(gather_linkings, bw_link)

        # Context managers guarantee both handles are closed even when
        # json.dump raises; the original left them open on error.
        with open(save_schema_fp + '.tmp', 'w') as bw_sc, \
                open(save_ans_fp + '.tmp', 'w') as bw_ans:
            # schema line i and answer line i describe the same schema.
            for schema_json, predict_list in zip(schema_json_list,
                                                 schema_ans_list):
                json.dump(predict_list, bw_ans)
                bw_ans.write('\n')
                json.dump(schema_json, bw_sc)
                bw_sc.write('\n')
        LogInfo.logs('Schema/Answer/Link saving complete.')

        # Atomic publish: rename the finished .tmp files into place.
        shutil.move(save_schema_fp + '.tmp', save_schema_fp)
        shutil.move(save_link_fp + '.tmp', save_link_fp)
        shutil.move(save_ans_fp + '.tmp', save_ans_fp)

        # Release per-question data before the next iteration.
        del gather_linkings
        del schema_ans_list
        del schema_json_list

        LogInfo.end_track()  # End of Q
def answer_collection(args):
    """Collect the predicted answers of the selected schemas into one file.

    Reads '<result_dir>/test_schema_<epoch>.txt' (lines of
    'q_idx<TAB>sc.ori_idx'), fetches the matching answer line from the
    per-question '*_ans' files under data_dir, and writes
    '<utterance><TAB><json answer list>' lines to
    '<result_dir>/test_predict_<epoch>.txt'.
    """
    assert args.data_name in ('WebQ', 'CompQ')
    result_dir = args.result_dir
    data_dir = args.data_dir
    epoch = args.epoch
    st = args.q_st
    ed = args.q_ed
    if args.data_name == 'WebQ':
        qa_list = load_webq()
        if st == ed == -1:      # default: the WebQ test split
            st = 3778
            ed = 5810
    else:
        qa_list = load_compq()
        if st == ed == -1:      # default: the CompQ test split
            st = 1300
            ed = 2100
    select_schema_fp = '%s/test_schema_%03d.txt' % (result_dir, epoch)
    select_dict = {}        # <q_idx, sc.ori_idx>
    test_predict_fp = '%s/test_predict_%03d.txt' % (result_dir, epoch)

    LogInfo.begin_track('Collecting answers from %s ...', select_schema_fp)
    LogInfo.logs('Original schema data: %s', data_dir)
    with open(select_schema_fp, 'r') as br:
        # Iterate the file lazily instead of readlines() -- no need to
        # materialize the whole file in memory.
        for line in br:
            spt = line.strip().split('\t')
            select_dict[int(spt[0])] = int(spt[1])

    save_tups = []
    for scan_idx, q_idx in enumerate(range(st, ed)):
        if scan_idx % 50 == 0:
            LogInfo.logs('%d questions scanned.', scan_idx)
        utterance = qa_list[q_idx]['utterance']
        sc_idx = select_dict.get(q_idx, -1)
        if sc_idx == -1:
            ans_list = []       # no answer at all
        else:
            # Explicit floor division: '/' yields a float in Python 3.
            div = q_idx // 100
            sub_dir = '%d-%d' % (div*100, div*100+99)
            ans_fp = '%s/%s/%d_ans' % (data_dir, sub_dir, q_idx)
            # linecache uses 1-based line numbers and caches the file
            # across calls, so repeated lookups into the same file are cheap.
            ans_line = linecache.getline(ans_fp, lineno=sc_idx+1).strip()
            LogInfo.logs('q_idx=%d, lineno=%d, content=[%s]', q_idx, sc_idx+1, ans_line)
            if args.data_name == 'CompQ':
                ans_line = ans_line.lower()
            ans_list = json.loads(ans_line)
        save_tups.append((utterance, ans_list))

    with codecs.open(test_predict_fp, 'w', 'utf-8') as bw:
        for utterance, ans_list in save_tups:
            bw.write(utterance + '\t')
            json.dump(ans_list, bw)
            bw.write('\n')
    LogInfo.end_track('Predicting results saved to %s.', test_predict_fp)
    def renew_data_list(self):
        """
        Target: construct the np_data_list, storing all <q, sc> data
        np_data_list contains:
        1. q_words      (data_size, q_max_len)
        2. q_words_len  (data_size, )
        3. sc_len       (data_size, )
        4. preds        (data_size, sc_max_len, path_max_len)
        5. preds_len    (data_size, sc_max_len)
        6. pwords       (data_size, sc_max_len, pword_max_len)
        7. pwords_len   (data_size, sc_max_len)
        """
        if self.verbose >= 1:
            LogInfo.begin_track(
                '[CompqSingleDataLoader] prepare data for [%s] ...', self.mode)
        q_idx_list = get_q_range_by_mode(data_name=self.dataset.data_name,
                                         mode=self.mode)
        # List comprehension instead of filter(): len() below needs a
        # concrete list, and filter() returns a lazy iterator on Python 3.
        filt_q_idx_list = [q for q in q_idx_list
                           if q in self.dataset.q_cand_dict]
        self.question_size = len(filt_q_idx_list)
        # Got all the related questions

        # One bucket per tensor:
        # [qwords, qwords_len, sc_len, preds, preds_len, pwords, pwords_len]
        emb_pools = [[] for _ in range(self.dataset.array_num)]
        # Different from the original complex DataLoader:
        # no schemas share the same qwords.
        self.cand_tup_list = []
        for scan_idx, q_idx in enumerate(filt_q_idx_list):
            if self.verbose >= 1 and scan_idx % self.proc_ob_num == 0:
                LogInfo.logs('%d / %d prepared.', scan_idx,
                             len(filt_q_idx_list))
            cand_list = self.dataset.q_cand_dict[q_idx]

            # Store each schema's tensor inputs into the matching bucket.
            for sc in cand_list:
                self.cand_tup_list.append((q_idx, sc))
                sc_tensor_inputs = self.dataset.get_schema_tensor_inputs(sc)
                for local_list, sc_tensor in zip(emb_pools, sc_tensor_inputs):
                    local_list.append(sc_tensor)
            # now the detail input of the schema (qwords, preds, pwords)
            # is copied into the memory of the dataloader

        # Finally: merge inputs together, producing the final np_data_list
        self.np_data_list = [np.array(target_list, dtype='int32')
                             for target_list in emb_pools]
        self.update_statistics()
        assert len(self.cand_tup_list) == self.np_data_list[0].shape[0]

        if self.verbose >= 1:
            LogInfo.logs('%d <q, sc> data collected.',
                         self.np_data_list[0].shape[0])
            LogInfo.end_track()
示例#30
0
    def post_process(self, eval_dl, detail_fp, result_fp):
        """Attach run-time scores to candidates, rank them per question,
        and compute the average F1 over all evaluated questions.

        :param eval_dl: data loader exposing eval_sc_tup_list (list of
                        (q_idx, candidate)) and total_questions.
        :param detail_fp: optional path; when given, a human-readable
                          ranking dump is written there.
        :param result_fp: unused here; kept for interface compatibility.
        :return: single-element list [average F1].
        """
        import sys  # local import: only needed for the print threshold

        if len(self.eval_detail_dict) == 0:
            self.eval_detail_dict = {k: [] for k in self.concern_name_list}

        ret_q_score_dict = {}
        for scan_idx, (q_idx, cand) in enumerate(eval_dl.eval_sc_tup_list):
            # put all output results into sc.run_info
            cand.run_info = {
                k: self.eval_detail_dict[k][scan_idx]
                for k in self.concern_name_list
            }
            ret_q_score_dict.setdefault(q_idx, []).append(cand)

        f1_list = []
        for q_idx, score_list in ret_q_score_dict.items():
            score_list.sort(key=lambda x: x.run_info['full_score'],
                            reverse=True)  # sort by score DESC
            if len(score_list) == 0:
                f1_list.append(0.)
            else:
                f1_list.append(score_list[0].f1)
        LogInfo.logs('[%3s] Predict %d out of %d questions.', self.name,
                     len(f1_list), eval_dl.total_questions)
        # Questions with no candidate at all contribute 0 to the average.
        ret_metric = np.sum(f1_list).astype(
            'float32') / eval_dl.total_questions

        if detail_fp is not None:
            schema_dataset = eval_dl.schema_dataset
            # 'with' closes the file even if formatting below raises;
            # the original leaked the handle on any exception.
            with open(detail_fp, 'w') as bw:
                LogInfo.redirect(bw)
                # threshold=np.nan raises on modern NumPy; sys.maxsize is
                # the documented way to disable array summarization.
                np.set_printoptions(threshold=sys.maxsize)
                LogInfo.logs('Avg_f1 = %.6f', ret_metric)
                srt_q_idx_list = sorted(ret_q_score_dict.keys())
                for q_idx in srt_q_idx_list:
                    qa = schema_dataset.qa_list[q_idx]
                    q = qa['utterance']
                    LogInfo.begin_track('Q-%04d [%s]:', q_idx, q.encode('utf-8'))
                    srt_list = ret_q_score_dict[q_idx]  # already sorted
                    best_label_f1 = np.max([sc.f1 for sc in srt_list])
                    best_label_f1 = max(best_label_f1, 0.000001)
                    for rank, sc in enumerate(srt_list):
                        # show top-20 plus every schema reaching the best F1
                        if rank < 20 or sc.f1 == best_label_f1:
                            LogInfo.begin_track(
                                '#-%04d [F1 = %.6f] [row_in_file = %d]', rank + 1,
                                sc.f1, sc.ori_idx)
                            LogInfo.logs('full_score: %.6f',
                                         sc.run_info['full_score'])
                            show_overall_detail(sc)
                            LogInfo.end_track()
                    LogInfo.end_track()
                LogInfo.logs('Avg_f1 = %.6f', ret_metric)
                np.set_printoptions()  # reset output format
                LogInfo.stop_redirect()

        self.ret_q_score_dict = ret_q_score_dict
        return [ret_metric]