Example #1
def main(args):
    qa_list = load_simpq()
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    lex_name = args.lex_name
    # lex_name = 'filter-%s-xs-lex-score' % args.fb_subset
    cand_gen = SimpleQCandidateGenerator(fb_subset=args.fb_subset,
                                         lex_name=lex_name,
                                         wd_emb_util=wd_emb_util,
                                         vb=args.verbose)

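    # Candidates are generated question by question; links and schemas are cached under
    # output_dir/data/<bucket>/, so questions with an existing schema file are skipped on re-runs.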
    for q_idx, qa in enumerate(qa_list):
        LogInfo.begin_track('Entering Q %d / %d [%s]:',
                            q_idx, len(qa_list), qa['utterance'].encode('utf-8'))
        sub_idx = q_idx / 1000 * 1000
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 999)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        opt_sc_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(opt_sc_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue
        Tt.start('single_q')
        cand_gen.single_question_candgen(q_idx=q_idx, qa=qa,
                                         link_fp=link_fp, opt_sc_fp=opt_sc_fp)
        Tt.record('single_q')
        LogInfo.end_track()     # End of Q
Example #2
def main(args):
    assert args.data_name in ('WebQ', 'CompQ')
    if args.data_name == 'WebQ':
        qa_list = load_webq()
    else:
        qa_list = load_compq()

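    # Linking-only mode scans all questions without a SPARQL service;
    # otherwise one group of 100 questions is processed, with its own SPARQL / F1-stat caches.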
    if args.linking_only:
        query_srv = None
        q_start = 0
        q_end = len(qa_list)
    else:
        group_idx = args.group_idx
        q_start = group_idx * 100
        q_end = group_idx * 100 + 100
        if args.data_name == 'CompQ':
            sparql_cache_fp = 'runnings/acl18_cache/group_cache/sparql.g%02d.cache' % group_idx
            q_sc_cache_fp = 'runnings/acl18_cache/group_cache/q_sc_stat.g%02d.cache' % group_idx
        else:
            sparql_cache_fp = 'runnings/acl18_cache/group_cache_%s/sparql.g%02d.cache' % (args.data_name, group_idx)
            q_sc_cache_fp = 'runnings/acl18_cache/group_cache_%s/q_sc_stat.g%02d.cache' % (args.data_name, group_idx)
        query_srv = QueryService(
            sparql_cache_fp=sparql_cache_fp,
            qsc_cache_fp=q_sc_cache_fp, vb=1
        )
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb,
                                    dim_emb=args.dim_emb)
    smart_cand_gen = SMARTCandidateGenerator(data_name=args.data_name,
                                             lex_name=args.lex_name,
                                             wd_emb_util=wd_emb_util,
                                             query_srv=query_srv,
                                             allow_forever=args.allow_forever,
                                             vb=args.verbose,
                                             simple_type_match=args.simple_type,
                                             simple_time_match=args.simple_time)

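    # Only questions inside [q_start, q_end) are handled; questions whose schema file already exists are skipped.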
    for q_idx, qa in enumerate(qa_list):
        if q_idx < q_start or q_idx >= q_end:
            continue
        # if q_idx != 1302:
        #     continue
        LogInfo.begin_track('Entering Q %d / %d [%s]:', q_idx, len(qa_list), qa['utterance'].encode('utf-8'))
        sub_idx = q_idx / 100 * 100
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 99)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        # save_ans_fp = '%s/%d_ans' % (sub_dir, q_idx)
        opt_sc_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(opt_sc_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue
        Tt.start('single_q')
        smart_cand_gen.single_question_candgen(q_idx=q_idx, qa=qa,
                                               link_fp=link_fp,
                                               opt_sc_fp=opt_sc_fp,
                                               linking_only=args.linking_only)
        Tt.record('single_q')
        LogInfo.end_track()     # End of Q
Example #3
    def batch_schema_f1_query(self, q_id, states, level):
        """
        Perform an F1 query for each schema in the states.
        :param q_id: WebQ-xxx / CompQ-xxx
        :param states: [(schema, visit_arr)]
        :param level: coarse / typed / timed / ordinal
        :return: filtered states where each schema returns at least one answer.
        """
        Tt.start('%s_F1' % level)
        LogInfo.begin_track('Calculating F1 for %d %s schemas:', len(states), level)
        for idx, (sc, _) in enumerate(states):
            if idx % 100 == 0:
                LogInfo.logs('Current: %d / %d', idx, len(states))
            sparql_str = sc.build_sparql(simple_time_match=self.simple_time_match)
            tm_comp, tm_value, ord_comp, ord_rank, agg = sc.build_aux_for_sparql()
            allow_forever = self.allow_forever if tm_comp != 'None' else ''
            # don't specify "forever" if there is no time constraint
            q_sc_key = '|'.join([q_id, sparql_str,
                                 tm_comp, tm_value, allow_forever,
                                 ord_comp, ord_rank, agg])
            if self.vb >= 2:
                LogInfo.begin_track('Checking schema %d / %d:', idx, len(states))
                LogInfo.logs(sc.disp_raw_path())
                LogInfo.logs('var_types: %s', sc.var_types)
                LogInfo.logs(sparql_str)
            Tt.start('query_q_sc_stat')
            sc.ans_size, sc.p, sc.r, sc.f1 = self.query_srv.query_q_sc_stat(q_sc_key)
            Tt.record('query_q_sc_stat')
            if self.vb >= 2:
                LogInfo.logs('Answers = %d, P = %.6f, R = %.6f, F1 = %.6f', sc.ans_size, sc.p, sc.r, sc.f1)
                LogInfo.end_track()
        filt_states = filter(lambda _tup: _tup[0].ans_size > 0, states)
        LogInfo.end_track('%d / %d %s schemas kept with ans_size > 0.', len(filt_states), len(states), level)
        Tt.record('%s_F1' % level)
        return filt_states
def read_schemas_from_single_file(q_idx, schema_fp, gather_linkings,
                                  sc_max_len):
    """
    Read schemas from file and separate them into several groups.
    """
    general_list = []
    strict_list = []  # 2-hop must be mediator
    elegant_list = []  # 2-hop allows pred1.range == pred2.domain
    coherent_list = []  # 2-hop allows pred1.range \in pred2.domain
    global_lists = [strict_list, elegant_list, coherent_list, general_list]

    used_pred_name_dict = {}
    with codecs.open(schema_fp, 'r', 'utf-8') as br:
        lines = br.readlines()
        for ori_idx, line in enumerate(lines):
            Tt.start('read_single_line')
            sc = CompqSchema.read_schema_from_json(
                q_idx,
                json_line=line,
                gather_linkings=gather_linkings,
                use_ans_type_dist=False,
                placeholder_policy='ActiveOnly',
                full_constr=True,
                fix_dir=True)
            Tt.record('read_single_line')
            sc.ori_idx = ori_idx
            # Tt.start('construct_path')
            # sc.construct_path_list()        # create the path_list on-the-fly
            # Tt.record('construct_path')
            sc_len_list[len(sc.raw_paths)] += 1
            for _, _, pred_seq in sc.raw_paths:
                for pred in pred_seq:
                    if pred not in used_pred_name_dict:
                        used_pred_name_dict[pred] = get_pred_name(pred)

            if len(sc.raw_paths) > sc_max_len:
                continue
            Tt.start('classification')
            sc_class = schema_classification(sc)
            global_lists[sc_class].append(sc)
            Tt.record('classification')
            # if q_idx == 1353:
            #     LogInfo.logs('Q-idx = %4d, Line = %4d, category = %d', q_idx, ori_idx, sc_class)
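    # Make the tiers cumulative: elegant also contains strict, coherent contains elegant, general contains all.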
    for i in range(3):
        global_lists[i + 1] += global_lists[i]
    return global_lists, used_pred_name_dict
Example #5
    def optimize(self, data_loader, epoch_idx, ob_batch_num=10, summary_writer=None):
        if data_loader is None:
            return -1.

        LogInfo.logs('Debug mode: %s', self.debug_mode)
        Tt.start('optimize')
        Tt.start('prepare')
        self.prepare_data(data_loader=data_loader)
        # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        # run_metadata = tf.RunMetadata()
        # tf.reset_default_graph()
        Tt.record('prepare')
        scan_size = 0
        ret_loss = 0.
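        # One pass over all mini-batches; ret_loss keeps a size-weighted running average of the batch losses.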
        for batch_idx in range(data_loader.n_batch):
            # point = (epoch_idx - 1) * data_loader.n_batch + batch_idx
            Tt.start('allocate')
            local_data_list, _ = data_loader.get_next_batch()
            local_size = len(local_data_list[0])    # the first dimension is always batch size
            fd = {input_tf: local_data for input_tf, local_data in zip(self.optm_input_tf_list, local_data_list)}
            Tt.record('allocate')
            Tt.start('running')
            if self.debug_mode == 'Loss':
                local_loss, summary = self.sess.run(
                    [self.avg_loss, self.optm_summary],
                    feed_dict=fd,
                    # options=run_options,
                    # run_metadata=run_metadata
                )
            elif self.debug_mode == 'Grad':
                _, local_loss, summary = self.sess.run(
                    [self.clipped_gradients, self.avg_loss, self.optm_summary],
                    feed_dict=fd,
                    # options=run_options,
                    # run_metadata=run_metadata
                )
            else:       # Update or Raw
                _, local_loss, summary = self.sess.run(
                    [self.optm_step, self.avg_loss, self.optm_summary],
                    feed_dict=fd,
                    # options=run_options,
                    # run_metadata=run_metadata
                )
            Tt.record('running')

            Tt.start('display')
            ret_loss = (ret_loss * scan_size + local_loss * local_size) / (scan_size + local_size)
            scan_size += local_size
            if (batch_idx+1) % ob_batch_num == 0:
                LogInfo.logs('[optm-%s-B%d/%d] avg_loss = %.6f, scanned = %d/%d',
                             data_loader.mode,
                             batch_idx+1,
                             data_loader.n_batch,
                             ret_loss,
                             scan_size,
                             len(data_loader))
            # if summary_writer is not None:
            #     summary_writer.add_summary(summary, point)
            #     if batch_idx == 0:
            #         summary_writer.add_run_metadata(run_metadata, 'epoch-%d' % epoch_idx)
            Tt.record('display')

        Tt.record('optimize')
        return ret_loss
Example #6
    def perform_linking(self, q_idx, tok_list, entity_only=False):
        tok_list = map(lambda x: x.lower(), tok_list)  # all steps ignore case.
        Tt.start('entity')
        el_list = self.q_links_dict.get(q_idx, [])
        Tt.record('entity')

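        # Entity links come from the precomputed q_links_dict; type / time / ordinal links
        # are only collected when entity_only is False (i.e., not SimpleQuestions).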
        if not entity_only:
            Tt.start('type')
            tl_list = self.type_linking(tok_list)
            Tt.record('type')
            Tt.start('time')
            tml_list = self.time_linking(tok_list)
            Tt.record('time')
            Tt.start('ordinal')
            ord_list = self.ordinal_linking(tok_list)
            Tt.record('ordinal')
            gather_linkings = el_list + tl_list + tml_list + ord_list
        else:  # For SimpleQuestions
            gather_linkings = el_list

        for idx in range(len(gather_linkings)):
            gather_linkings[idx].gl_pos = idx
        return gather_linkings
Example #7
    def single_question_candgen(self, q_idx, qa, link_fp, opt_sc_fp):
        # =================== Linking first ==================== #
        Tt.start('linking')
        if os.path.isfile(link_fp):
            gather_linkings = []
            with codecs.open(link_fp, 'r', 'utf-8') as br:
                for line in br.readlines():
                    tup_list = json.loads(line.strip())
                    ld_dict = {k: v for k, v in tup_list}
                    gather_linkings.append(LinkData(**ld_dict))
            LogInfo.logs('Read %d links from file.', len(gather_linkings))
        else:
            tok_list = [tok.token for tok in qa['tokens']]
            gather_linkings = self.global_linker.perform_linking(
                q_idx=q_idx,
                tok_list=tok_list,
                entity_only=True
            )
        LogInfo.begin_track('Show %d E links :', len(gather_linkings))
        if self.vb >= 1:
            for gl in gather_linkings:
                LogInfo.logs(gl.display().encode('utf-8'))
        LogInfo.end_track()
        Tt.record('linking')

        # ==================== Save linking results ================ #
        if not os.path.isfile(link_fp):
            with codecs.open(link_fp + '.tmp', 'w', 'utf-8') as bw:
                for gl in gather_linkings:
                    bw.write(json.dumps(gl.serialize()) + '\n')
            shutil.move(link_fp + '.tmp', link_fp)
            LogInfo.logs('%d link data saved to file.', len(gather_linkings))

        # ===================== simple predicate finding ===================== #
        gold_entity, gold_pred, _ = qa['targetValue']
        sc_list = []
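        # Enumerate every (linked entity, 1-hop predicate) pair as a candidate schema;
        # only the gold (entity, predicate) pair receives P = R = F1 = 1.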
        for gl_data in gather_linkings:
            entity = gl_data.value
            pred_set = self.subj_pred_dict.get(entity, set([]))
            for pred in pred_set:
                sc = CompqSchema()
                sc.hops = 1
                sc.aggregate = False
                sc.main_pred_seq = [pred]
                sc.inv_main_pred_seq = [inverse_predicate(pred) for pred in sc.main_pred_seq]  # [!p1, !p2]
                sc.inv_main_pred_seq.reverse()  # [!p2, !p1]
                sc.raw_paths = [('Main', gl_data, [pred])]
                sc.ans_size = 1
                if entity == gold_entity and pred == gold_pred:
                    sc.f1 = sc.p = sc.r = 1.
                else:
                    sc.f1 = sc.p = sc.r = 0.
                sc_list.append(sc)

        # ==================== Save schema results ================ #
        # p, r, f1, ans_size, hops, raw_paths, (agg)
        # raw_paths: (category, gl_pos, gl_mid, pred_seq)
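        # A hypothetical output line for illustration (values are made up, not from real data):
        #   {"p": 1.0, "r": 1.0, "f1": 1.0, "ans_size": 1, "hops": 1, "agg": false,
        #    "raw_paths": [["Main", 0, "m.0abc12", ["people.person.nationality"]]]}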
        with codecs.open(opt_sc_fp + '.tmp', 'w', 'utf-8') as bw:
            for sc in sc_list:
                sc_info_dict = {k: getattr(sc, k) for k in ('p', 'r', 'f1', 'ans_size', 'hops')}
                if sc.aggregate is not None:
                    sc_info_dict['agg'] = sc.aggregate
                opt_raw_paths = []
                for cate, gl, pred_seq in sc.raw_paths:
                    opt_raw_paths.append((cate, gl.gl_pos, gl.value, pred_seq))
                sc_info_dict['raw_paths'] = opt_raw_paths
                bw.write(json.dumps(sc_info_dict) + '\n')
        shutil.move(opt_sc_fp + '.tmp', opt_sc_fp)
        LogInfo.logs('%d schemas successfully saved into [%s].', len(sc_list), opt_sc_fp)
Example #8
    cand_gen = SimpleQCandidateGenerator(fb_subset=args.fb_subset,
                                         lex_name=lex_name,
                                         wd_emb_util=wd_emb_util,
                                         vb=args.verbose)

    for q_idx, qa in enumerate(qa_list):
        LogInfo.begin_track('Entering Q %d / %d [%s]:',
                            q_idx, len(qa_list), qa['utterance'].encode('utf-8'))
        sub_idx = q_idx / 1000 * 1000
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 999)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        opt_sc_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(opt_sc_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue
        Tt.start('single_q')
        cand_gen.single_question_candgen(q_idx=q_idx, qa=qa,
                                         link_fp=link_fp, opt_sc_fp=opt_sc_fp)
        Tt.record('single_q')
        LogInfo.end_track()     # End of Q


if __name__ == '__main__':
    LogInfo.begin_track('[kangqi.task.compQA.candgen_acl18.simpq_candgen] ... ')
    _args = parser.parse_args()
    main(_args)
    LogInfo.end_track('All Done.')
    Tt.display()
def working_in_data(data_dir, file_list, qa_list, sc_max_len=3):

    save_fp = data_dir + '/log.schema_check'
    bw = open(save_fp, 'w')
    LogInfo.redirect(bw)

    LogInfo.begin_track('Working in %s, with schemas in %s:', data_dir,
                        file_list)
    with open(data_dir + '/' + file_list, 'r') as br:
        lines = br.readlines()

    used_pred_tup_list = []
    stat_tup_list = []
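    # Scan every listed schema file, read its linkings, and bucket its schemas into
    # strict / elegant / coherent / general tiers for the statistics below.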
    for line_idx, line in enumerate(lines):
        if line_idx % 100 == 0:
            LogInfo.logs('Scanning %d / %d ...', line_idx, len(lines))
        line = line.strip()
        q_idx = int(line.split('/')[2].split('_')[0])
        # if q_idx != 1353:
        #     continue
        schema_fp = data_dir + '/' + line
        link_fp = schema_fp.replace('schema', 'links')
        Tt.start('read_linkings')
        gather_linkings = []
        with codecs.open(link_fp, 'r', 'utf-8') as br:
            for gl_line in br.readlines():
                tup_list = json.loads(gl_line.strip())
                ld_dict = {k: v for k, v in tup_list}
                gather_linkings.append(LinkData(**ld_dict))
        Tt.record('read_linkings')
        Tt.start('read_schema')
        global_lists, used_pred_name_dict = \
            read_schemas_from_single_file(q_idx, schema_fp, gather_linkings, sc_max_len)
        Tt.record('read_schema')
        stat_tup_list.append((q_idx, global_lists))
        used_pred_tup_list.append((q_idx, used_pred_name_dict))

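    # Rank questions by how much the elegant tier raises the best reachable F1 over the strict tier.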
    stat_tup_list.sort(
        key=lambda _tup: max_f1(_tup[1][ELEGANT]) - max_f1(_tup[1][STRICT]),
        reverse=True)

    LogInfo.logs('sc_len distribution: %s', sc_len_list)

    LogInfo.logs('Rank\tQ_idx\tstri.\teleg.\tcohe.\tgene.'
                 '\tstri._F1\teleg._F1\tcohe._F1\tgene._F1\tUtterance')
    for rank, stat_tup in enumerate(stat_tup_list):
        q_idx, global_lists = stat_tup
        size_list = ['%4d' % len(cand_list) for cand_list in global_lists]
        max_f1_list = [
            '%8.6f' % max_f1(cand_list) for cand_list in global_lists
        ]
        size_str = '\t'.join(size_list)
        max_f1_str = '\t'.join(max_f1_list)
        LogInfo.logs('%4d\t%4d\t%s\t%s\t%s', rank + 1, q_idx, size_str,
                     max_f1_str, qa_list[q_idx]['utterance'].encode('utf-8'))

    q_size = len(stat_tup_list)
    f1_upperbound_list = []
    avg_cand_size_list = []
    found_entity_list = []
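    # Aggregate per-tier statistics over all questions: F1 upper bound, average candidate size,
    # and the number of questions with at least one F1 > 0 candidate.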
    for index in range(4):
        local_max_f1_list = []
        local_cand_size_list = []
        for _, global_lists in stat_tup_list:
            sc_list = global_lists[index]
            local_max_f1_list.append(max_f1(sc_list))
            local_cand_size_list.append(len(sc_list))
        local_found_entity = len(filter(lambda y: y > 0., local_max_f1_list))
        f1_upperbound_list.append(1.0 * sum(local_max_f1_list) / q_size)
        avg_cand_size_list.append(1.0 * sum(local_cand_size_list) / q_size)
        found_entity_list.append(local_found_entity)

    LogInfo.logs('  strict: avg size = %.6f, upp F1 = %.6f',
                 avg_cand_size_list[STRICT], f1_upperbound_list[STRICT])
    for name, index in zip(('strict', 'elegant', 'coherent', 'general'),
                           (STRICT, ELEGANT, COHERENT, GENERAL)):
        avg_size = avg_cand_size_list[index]
        ratio = 1. * avg_size / avg_cand_size_list[STRICT] if avg_cand_size_list[STRICT] > 0 else 0.
        found_entity = found_entity_list[index]
        upp_f1 = f1_upperbound_list[index]
        gain = upp_f1 - f1_upperbound_list[STRICT]
        LogInfo.logs(
            '%8s: avg size = %.6f (%.2fx), found entity = %d, '
            'upp F1 = %.6f, gain = %.6f', name, avg_size, ratio, found_entity,
            upp_f1, gain)
    LogInfo.end_track()

    bw.close()
    LogInfo.stop_redirect()

    used_pred_fp = data_dir + '/log.used_pred_stat'
    bw = codecs.open(used_pred_fp, 'w', 'utf-8')
    LogInfo.redirect(bw)
    ratio_list = []
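    # For each question, compare distinct predicate IDs against distinct predicate names
    # to gauge how often different predicates collapse to the same surface name.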
    for q_idx, used_pred_name_dict in used_pred_tup_list:
        unique_id_size = len(set(used_pred_name_dict.keys()))
        unique_name_size = len(set(used_pred_name_dict.values()))
        ratio = 1. * unique_name_size / unique_id_size if unique_id_size > 0 else 1.
        ratio_list.append(ratio)
        LogInfo.logs('Q = %4d, unique id = %d, unique name = %d, ratio = %.4f',
                     q_idx, unique_id_size, unique_name_size, ratio)
    avg_ratio = sum(ratio_list) / len(ratio_list)
    LogInfo.logs('avg_ratio = %.4f', avg_ratio)
    LogInfo.stop_redirect()
Example #10
    def coarse_search(self, path_len, entity_linkings, conflict_matrix,
                      cur_el_idx, cur_comb, visit_arr, coarse_state_list,
                      state_marker, aggregate):
        if self.vb >= 1:
            LogInfo.begin_track('[%s] (%s)', '||'.join(state_marker),
                                datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        fuzzy_query = self.build_fuzzy_query(path_len, entity_linkings,
                                             cur_comb)
        if self.vb >= 1:
            LogInfo.logs('Fuzzy query: %s', fuzzy_query)

        st = time.time()
        Tt.start('query_sparql')
        query_ret = self.query_service.query_sparql(fuzzy_query)
        Tt.record('query_sparql')
        if self.vb >= 4:
            LogInfo.logs('Recv >>>: %s', query_ret)
        filt_query_ret = predicate_filtering(query_ret=query_ret,
                                             path_len=path_len)
        if self.vb >= 1:
            LogInfo.logs('Filt_Query_Ret = %d / %d (%.3fs)',
                         len(filt_query_ret), len(query_ret),
                         time.time() - st)
        if len(filt_query_ret) == 0:
            if self.vb >= 1:
                LogInfo.end_track()
            return  # no need to search deeper

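        # Turn the filtered query results into coarse schemas and store each with a copy of the
        # current visit state, so later constraint stages can resume the search from this combination.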
        coarse_schemas = self.build_coarse_schemas(
            path_len=path_len,
            el_linkings=entity_linkings,
            cur_comb=cur_comb,
            filt_query_ret=filt_query_ret,
            aggregate=aggregate)
        for sc in coarse_schemas:
            # save coarse schemas into the output state
            coarse_state_list.append((sc, list(visit_arr)))

        # Ready to search deeper
        el_size = len(entity_linkings)
        for nxt_el_idx in range(cur_el_idx + 1, el_size):
            gl_pos = entity_linkings[nxt_el_idx].gl_pos
            if visit_arr[gl_pos] != 0:  # cannot be visited due to conflict
                continue
            for conf_idx in conflict_matrix[gl_pos]:  # ready to enter the next state
                visit_arr[conf_idx] += 1
            for attach_idx in range(1, path_len + 1):  # enumerate each possible attach position
                nxt_comb = list(cur_comb)
                nxt_comb.append((attach_idx, nxt_el_idx))
                state_marker.append('%d/%d-%d' % (nxt_el_idx + 1, el_size, attach_idx))
                self.coarse_search(path_len=path_len,
                                   entity_linkings=entity_linkings,
                                   conflict_matrix=conflict_matrix,
                                   cur_el_idx=nxt_el_idx,
                                   cur_comb=nxt_comb,
                                   visit_arr=visit_arr,
                                   coarse_state_list=coarse_state_list,
                                   state_marker=state_marker,
                                   aggregate=aggregate)
                del state_marker[-1]
            for conf_idx in conflict_matrix[gl_pos]:  # return back
                visit_arr[conf_idx] -= 1
        # Ends of DFS
        if self.vb >= 1:
            LogInfo.end_track()
Example #11
    def single_question_candgen(self, q_idx, qa, link_fp, opt_sc_fp, linking_only):
        # =================== Linking first ==================== #
        Tt.start('linking')
        if os.path.isfile(link_fp):
            gather_linkings = []
            with codecs.open(link_fp, 'r', 'utf-8') as br:
                for line in br.readlines():
                    tup_list = json.loads(line.strip())
                    ld_dict = {k: v for k, v in tup_list}
                    gather_linkings.append(LinkData(**ld_dict))
            LogInfo.logs('Read %d links from file.', len(gather_linkings))
        else:
            tok_list = [tok.token for tok in qa['tokens']]
            gather_linkings = self.global_linker.perform_linking(q_idx=q_idx, tok_list=tok_list)
        el_size = len(filter(lambda x: x.category == 'Entity', gather_linkings))
        tl_size = len(filter(lambda x: x.category == 'Type', gather_linkings))
        tml_size = len(filter(lambda x: x.category == 'Time', gather_linkings))
        ord_size = len(filter(lambda x: x.category == 'Ordinal', gather_linkings))
        LogInfo.begin_track('Show %d E + %d T + %d Tm + %d Ord = %d linkings:',
                            el_size, tl_size, tml_size, ord_size, len(gather_linkings))
        if self.vb >= 1:
            for gl in gather_linkings:
                LogInfo.logs(gl.display())
        LogInfo.end_track()
        Tt.record('linking')

        # ==================== Save linking results ================ #
        if not os.path.isfile(link_fp):
            with codecs.open(link_fp + '.tmp', 'w', 'utf-8') as bw:
                for gl in gather_linkings:
                    bw.write(json.dumps(gl.serialize()) + '\n')
            shutil.move(link_fp + '.tmp', link_fp)
            LogInfo.logs('%d link data saved to file.', len(gather_linkings))
        if linking_only:
            return

        # ===================== Prepare necessary states for linking =============== #
        q_id = '%s_%d' % (self.data_name, q_idx)
        conflict_matrix = construct_conflict_matrix(gather_linkings)
        entity_linkings = filter(lambda x: x.category == 'Entity', gather_linkings)
        type_linkings = filter(lambda x: x.category == 'Type', gather_linkings)
        time_linkings = filter(lambda x: x.category == 'Time', gather_linkings)
        ordinal_linkings = filter(lambda x: x.category == 'Ordinal', gather_linkings)

        # ============================== Searching start ========================= #
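        # The search proceeds in four cumulative stages: coarse (entities only) -> typed -> timed -> ordinal.
        # After each stage, batch_schema_f1_query keeps only schemas that return at least one answer.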
        aggregate = False
        """ 180308: always ignore aggregation, since we've found the f****d up data distribution """
        # lower_tok_list = [tok.token.lower() for tok in qa['tokens']]
        # lower_tok_str = ' '.join(lower_tok_list)
        # aggregate = lower_tok_str.startswith('how many ')
        # # apply COUNT(*) to all candidate schemas if we found "how many" at the beginning of a question

        """ ================ Step 1: coarse linking, using entities only ================ """
        LogInfo.begin_track('Coarse level searching (total entities = %d):', el_size)
        entity_linkings.sort(key=lambda _el: _el.value)
        """ We sort the linking data, for identifying potential duplicate SPARQL queries and saving time. """
        Tt.start('coarse_comb')
        coarse_states = self.cand_searcher.find_coarse_schemas(
            entity_linkings=entity_linkings, conflict_matrix=conflict_matrix, aggregate=aggregate)
        # [(schema, visit_arr)]
        LogInfo.logs('%d coarse schemas retrieved from scratch.', len(coarse_states))
        Tt.record('coarse_comb')
        coarse_states = self.batch_schema_f1_query(q_id=q_id, states=coarse_states, level='coarse')
        LogInfo.end_track('Coarse level ended, resulting in %d schemas.', len(coarse_states))

        """ ================ Step 2: adding type information ================ """
        LogInfo.begin_track('Type level searching (total types = %d):', tl_size)
        Tt.start('typed_comb')
        typed_states = []
        for idx, (coarse_sc, visit_arr) in enumerate(coarse_states):
            if idx % 100 == 0:
                LogInfo.logs('Current: %d / %d', idx, len(coarse_states))
            typed_states += self.cand_searcher.find_typed_schemas(
                type_linkings=type_linkings, conflict_matrix=conflict_matrix,
                start_schema=coarse_sc, visit_arr=visit_arr,
                simple_type_match=self.simple_type_match
            )
        LogInfo.logs('%d typed schemas retrieved from %d coarse schemas.', len(typed_states), len(coarse_states))
        Tt.record('typed_comb')
        typed_states = self.batch_schema_f1_query(q_id=q_id, states=typed_states, level='typed')
        typed_states = coarse_states + typed_states     # don't forget those schemas without type constraints
        LogInfo.end_track('Typed level ended, resulting in %d schemas.', len(typed_states))

        """ ================ Step 3: adding time information ================ """
        LogInfo.begin_track('Time level searching (total times = %d):', tml_size)
        Tt.start('timed_comb')
        timed_states = []
        for idx, (typed_sc, visit_arr) in enumerate(typed_states):
            if idx % 100 == 0:
                LogInfo.logs('Current: %d / %d', idx, len(typed_states))
            timed_states += self.cand_searcher.find_timed_schemas(
                time_linkings=time_linkings, conflict_matrix=conflict_matrix,
                start_schema=typed_sc, visit_arr=visit_arr,
                simple_time_match=self.simple_time_match    # True: degenerates into Bao
            )
        LogInfo.logs('%d timed schemas retrieved from %d typed schemas.', len(timed_states), len(typed_states))
        Tt.record('timed_comb')
        timed_states = self.batch_schema_f1_query(q_id=q_id, states=timed_states, level='timed')
        timed_states = typed_states + timed_states      # don't forget the previous schemas
        LogInfo.end_track('Time level ended, resulting in %d schemas.', len(timed_states))

        """ ================ Step 4: ordinal information as the final step ================ """
        LogInfo.begin_track('Ordinal level searching (total ordinals = %d):', ord_size)
        Tt.start('ord_comb')
        final_states = []
        for idx, (timed_sc, visit_arr) in enumerate(timed_states):
            if idx % 100 == 0:
                LogInfo.logs('Current: %d / %d', idx, len(timed_states))
            final_states += self.cand_searcher.find_ordinal_schemas(ordinal_linkings=ordinal_linkings,
                                                                    start_schema=timed_sc,
                                                                    visit_arr=visit_arr)
        LogInfo.logs('%d ordinal schemas retrieved from %d timed schemas.', len(final_states), len(timed_states))
        Tt.record('ord_comb')
        final_states = self.batch_schema_f1_query(q_id=q_id, states=final_states, level='ordinal')
        final_states = timed_states + final_states
        LogInfo.end_track('Ordinal level ended, we finally collected %d schemas.', len(final_states))

        final_schemas = [tup[0] for tup in final_states]
        self.query_srv.save_buffer()

        # ==================== Save schema results ================ #
        # p, r, f1, ans_size, hops, raw_paths, (agg)
        # raw_paths: (category, gl_pos, gl_mid, pred_seq)
        with codecs.open(opt_sc_fp + '.tmp', 'w', 'utf-8') as bw:
            for sc in final_schemas:
                sc_info_dict = {k: getattr(sc, k) for k in ('p', 'r', 'f1', 'ans_size', 'hops')}
                if sc.aggregate is not None:
                    sc_info_dict['agg'] = sc.aggregate
                opt_raw_paths = []
                for cate, gl, pred_seq in sc.raw_paths:
                    opt_raw_paths.append((cate, gl.gl_pos, gl.value, pred_seq))
                sc_info_dict['raw_paths'] = opt_raw_paths
                bw.write(json.dumps(sc_info_dict) + '\n')
        shutil.move(opt_sc_fp + '.tmp', opt_sc_fp)
        LogInfo.logs('%d schemas successfully saved into [%s].', len(final_schemas), opt_sc_fp)

        del coarse_states
        del typed_states
        del timed_states
        del final_states
Example #12
def main(args):
    LogInfo.begin_track('Learning starts ... ')

    # ==== Loading Necessary Util ==== #
    LogInfo.begin_track('Loading Utils ... ')
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    LogInfo.end_track()

    # ==== Loading Dataset ==== #
    data_config = literal_eval(args.data_config)    # including data_name, dir, max_length and others
    data_config['wd_emb_util'] = wd_emb_util
    # data_config['kb_emb_util'] = kb_emb_util
    data_config['verbose'] = args.verbose
    dataset = QScDataset(**data_config)
    dataset.load_size()  # load size info

    # ==== Build Model First ==== #
    LogInfo.begin_track('Building Model and Session ... ')
    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            intra_op_parallelism_threads=8))
    rm_config = literal_eval(args.rm_config)    # Relation Matching
    rm_name = rm_config['name']
    del rm_config['name']
    assert rm_name in ('Compact', 'Separated')

    rm_config['n_words'] = dataset.word_size
    rm_config['n_mids'] = dataset.mid_size
    rm_config['dim_emb'] = args.dim_emb
    rm_config['q_max_len'] = dataset.q_max_len
    rm_config['sc_max_len'] = dataset.sc_max_len
    rm_config['path_max_len'] = dataset.path_max_len
    rm_config['pword_max_len'] = dataset.path_max_len * dataset.item_max_len
    rm_config['verbose'] = args.verbose
    if rm_name == 'Compact':
        LogInfo.logs('RelationMatchingKernel: Compact')
        rm_kernel = CompactRelationMatchingKernel(**rm_config)
    else:
        LogInfo.logs('RelationMatchingKernel: Separated')
        rm_kernel = SeparatedRelationMatchingKernel(**rm_config)
    el_kernel = EntityLinkingKernel(
        e_max_size=dataset.e_max_size, e_feat_len=dataset.e_feat_len, verbose=args.verbose)

    model_config = literal_eval(args.model_config)
    model_config['sess'] = sess
    model_config['objective'] = args.eval_mode      # relation_only / normal
    model_config['relation_kernel'] = rm_kernel
    model_config['entity_kernel'] = el_kernel
    model_config['extra_len'] = dataset.extra_len
    model_config['verbose'] = args.verbose
    compq_model = CompqModel(**model_config)

    LogInfo.begin_track('Showing final parameters: ')
    for var in tf.global_variables():
        LogInfo.logs('%s: %s', var.name, var.get_shape().as_list())
    LogInfo.end_track()
    saver = tf.train.Saver()

    LogInfo.begin_track('Parameter initializing ... ')
    start_epoch = 0
    best_valid_f1 = 0.
    resume_flag = False
    model_dir = None
    if args.resume_model_name not in ('', 'None'):
        model_dir = '%s/%s' % (args.output_dir, args.resume_model_name)
        if os.path.exists(model_dir):
            resume_flag = True
    if resume_flag:
        start_epoch, best_valid_f1 = load_model(saver=saver, sess=sess, model_dir=model_dir)
    else:
        dataset.load_init_emb()  # loading parameters for embedding initialize
        LogInfo.logs('Running global_variables_initializer ...')
        sess.run(tf.global_variables_initializer(),
                 feed_dict={rm_kernel.w_embedding_init: dataset.word_init_emb,
                            rm_kernel.m_embedding_init: dataset.mid_init_emb})
    LogInfo.end_track('Start Epoch = %d', start_epoch)
    LogInfo.end_track('Model Built.')
    tf.get_default_graph().finalize()

    # ==== Constructing Data_Loader ==== #
    LogInfo.begin_track('Creating DataLoader ... ')
    dataset.load_cands()  # first loading all the candidates
    if args.eval_mode == 'relation_only':
        ro_change = 0
        for cand_list in dataset.q_cand_dict.values():
            ro_change += add_relation_only_metric(cand_list)    # for "RelationOnly" evaluation
        LogInfo.logs('RelationOnly F1 change: %d schemas affected.', ro_change)

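    # Training uses a pairwise loader (CompqPairDataLoader) driven by the Neg-* sampling setting,
    # while evaluation uses per-candidate loaders (CompqSingleDataLoader) for train / valid / test.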
    optm_dl_config = {'dataset': dataset, 'mode': 'train',
                      'batch_size': args.optm_batch_size, 'proc_ob_num': 5000, 'verbose': args.verbose}
    eval_dl_config = dict(optm_dl_config)
    spt = args.dl_neg_mode.split('-')       # Neg-${POOR_CONTRIB}-${POOR_MAX_SAMPLE}
    optm_dl_config['poor_contribution'] = int(spt[1])
    optm_dl_config['poor_max_sample'] = int(spt[2])
    optm_dl_config['shuffle'] = False
    optm_train_data = CompqPairDataLoader(**optm_dl_config)

    eval_dl_config['batch_size'] = args.eval_batch_size
    eval_data_group = []
    for mode in ('train', 'valid', 'test'):
        eval_dl_config['mode'] = mode
        eval_data = CompqSingleDataLoader(**eval_dl_config)
        eval_data.renew_data_list()
        eval_data_group.append(eval_data)
    (eval_train_data, eval_valid_data, eval_test_data) = eval_data_group
    LogInfo.end_track()  # End of loading data & dataset

    # ==== Free memories ==== #
    for item in (wd_emb_util, dataset.wd_emb_util, data_config):
        del item

    # ==== Ready for learning ==== #
    LogInfo.begin_track('Learning start ... ')
    output_dir = args.output_dir
    if not os.path.exists(output_dir + '/detail'):
        os.makedirs(output_dir + '/detail')
    if not os.path.exists(output_dir + '/result'):
        os.makedirs(output_dir + '/result')
    if os.path.isdir(output_dir + '/TB'):
        shutil.rmtree(output_dir + '/TB')
    tf.summary.FileWriter(output_dir + '/TB/optm', sess.graph)      # saving model graph information
    # optm_summary_writer = tf.summary.FileWriter(output_dir + '/TB/optm', sess.graph)
    # eval_train_summary_writer = tf.summary.FileWriter(output_dir + '/TB/eval_train', sess.graph)
    # eval_valid_summary_writer = tf.summary.FileWriter(output_dir + '/TB/eval_valid', sess.graph)
    # eval_test_summary_writer = tf.summary.FileWriter(output_dir + '/TB/eval_test', sess.graph)
    # LogInfo.logs('TensorBoard writer defined.')
    # TensorBoard information

    status_fp = output_dir + '/status.csv'
    with open(status_fp, 'a') as bw:
        bw.write('%s\t%s\t%s\t%s\t%s\t%s\t%s\n' % (
            'Epoch', 'T_loss', 'T_F1', 'v_F1', 'Status', 't_F1',
            datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        ))
    patience = args.max_patience

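    # Train until max_epoch, or stop early once validation F1 fails to improve for max_patience consecutive epochs.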
    for epoch in range(start_epoch + 1, args.max_epoch + 1):
        if patience == 0:
            LogInfo.logs('Early stopping at epoch = %d.', epoch)
            break

        LogInfo.begin_track('Epoch %d / %d', epoch, args.max_epoch)
        update_flag = False

        LogInfo.begin_track('Optimizing ...')
        train_loss = compq_model.optimize(optm_train_data, epoch, ob_batch_num=1)
        LogInfo.end_track('T_loss = %.6f', train_loss)

        LogInfo.begin_track('Eval-Training ...')
        train_f1 = compq_model.evaluate(eval_train_data, epoch, ob_batch_num=50,
                                        detail_fp=output_dir + '/detail/train_%03d.txt' % epoch)
        LogInfo.end_track('T_F1 = %.6f', train_f1)

        LogInfo.begin_track('Eval-Validating ...')
        valid_f1 = compq_model.evaluate(eval_valid_data, epoch, ob_batch_num=50,
                                        detail_fp=output_dir + '/detail/valid_%03d.txt' % epoch)
        LogInfo.logs('v_F1 = %.6f', valid_f1)
        if valid_f1 > best_valid_f1:
            best_valid_f1 = valid_f1
            update_flag = True
            patience = args.max_patience
        else:
            patience -= 1
        LogInfo.logs('Model %s, best v_F1 = %.6f [patience = %d]',
                     'updated' if update_flag else 'stayed',
                     best_valid_f1,
                     patience)
        LogInfo.end_track()

        LogInfo.begin_track('Eval-Testing ... ')
        test_f1 = compq_model.evaluate(eval_test_data, epoch, ob_batch_num=20,
                                       detail_fp=output_dir + '/detail/test_%03d.txt' % epoch,
                                       result_fp=output_dir + '/result/test_schema_%03d.txt' % epoch)
        LogInfo.end_track('t_F1 = %.6f', test_f1)

        with open(status_fp, 'a') as bw:
            bw.write('%d\t%8.6f\t%8.6f\t%8.6f\t%s\t%8.6f\t%s\n' % (
                epoch, train_loss, train_f1, valid_f1,
                'UPDATE' if update_flag else str(patience), test_f1,
                datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
            ))
        save_epoch_dir = '%s/model_epoch_%d' % (output_dir, epoch)
        save_best_dir = '%s/model_best' % output_dir
        if args.save_epoch:
            delete_dir(save_epoch_dir)
            save_model(saver=saver, sess=sess, model_dir=save_epoch_dir, epoch=epoch, valid_metric=valid_f1)
            if update_flag and args.save_best:  # just create a symbolic link
                delete_dir(save_best_dir)
                os.symlink(save_epoch_dir, save_best_dir)  # symlink at directory level
        elif update_flag and args.save_best:
            delete_dir(save_best_dir)
            save_model(saver=saver, sess=sess, model_dir=save_best_dir, epoch=epoch, valid_metric=valid_f1)

        LogInfo.end_track()  # End of epoch
    LogInfo.end_track()  # End of learning

    Tt.display()
    LogInfo.end_track('All Done.')