示例#1
0
def main(args):
    """Run the Lukov entity linker over a dataset and save one link per question.

    For ``args.data_name == "SimpQ"`` all SimpleQuestions splits are linked
    into ``<data_dir>/SimpQ.all.links``; otherwise the Reddit split
    ``args.mode`` is linked into ``<data_dir>/Reddit.<mode>.links``.

    Each output line is tab-separated:
    ``q_idx(%04d), start, end, mention, mid, name, json(feat_dict)``.
    """
    if args.data_name != "SimpQ":
        qa_list = load_reddit(args.data_dir, mode=args.mode)
        output_file = "%s/Reddit.%s.links" % (args.data_dir, args.mode)
    else:
        qa_list = load_simpq(args.data_dir)
        output_file = "%s/SimpQ.all.links" % args.data_dir

    fb_path = "%s/freebase-FB2M.txt" % args.fb_dir
    meta_dir = args.fb_meta_dir
    linker = LukovLinker(
        freebase_fp=fb_path,
        mid_name_fp="%s/S-NAP-ENO-triple.txt" % meta_dir,
        type_name_fp="%s/TS-name.txt" % meta_dir,
        pred_name_fp="%s/PS-name.txt" % meta_dir,
        entity_pop_fp="%s/entity_pop_5m.txt" % meta_dir,
        type_pop_fp="%s/type_pop.txt" % meta_dir,
    )

    line_fmt = '%04d\t%d\t%d\t%s\t%s\t%s\t%s\n'
    LogInfo.begin_track('Linking data save to: %s' % output_file)
    with codecs.open(output_file, 'w', 'utf-8') as bw:
        for q_idx, qa in enumerate(qa_list):
            tokens = qa['tokens']
            # Progress heartbeat every 10000 questions (skipping index 0).
            if q_idx and q_idx % 10000 == 0:
                LogInfo.logs('Entering Q-%d', q_idx)
            best = linker.link_single_question(tokens)
            row = (q_idx, best.start, best.end, best.mention, best.mid,
                   best.name, json.dumps(best.feat_dict))
            bw.write(line_fmt % row)
    LogInfo.end_track()
示例#2
0
 def _load_linkings(self, links_fp):
     """Load entity-linking results into ``self.p_links_dict``.

     Each data line of ``links_fp`` is tab-separated with seven fields:
     ``q_idx, start, end, mention, mid, name, feats_json``. Feature values
     are rounded to 6 decimals. Links are grouped per integer question index
     as ``LinkData`` objects (category 'Entity', comparator '==').

     Lines starting with '#' and blank lines are skipped.

     :param links_fp: path of the UTF-8 links file.
     """
     with codecs.open(links_fp, 'r', 'utf-8') as br:
         # Stream the file instead of materialising it with readlines().
         for line in br:
             if line.startswith('#'):
                 continue
             line = line.strip()
             # A blank line (e.g. a stray trailing newline) would otherwise
             # crash the 7-field unpack below.
             if not line:
                 continue
             q_idx, st, ed, mention, mid, wiki_name, feats = line.split('\t')
             # Round every feature through a '%.6f' string (6 decimals).
             feat_dict = {k: float('%.6f' % v)
                          for k, v in json.loads(feats).items()}
             link_data = LinkData(category='Entity',
                                  start=int(st),
                                  end=int(ed),
                                  mention=mention,
                                  comp='==',
                                  value=mid,
                                  name=wiki_name,
                                  link_feat=feat_dict)
             self.p_links_dict.setdefault(int(q_idx), []).append(link_data)
     LogInfo.logs('%d questions of link data loaded.',
                  len(self.p_links_dict))
示例#3
0
 def _load_pop_dict(self, entity_pop_fp, type_pop_fp):
     """Read <mid, popularity> pairs into ``self.pop_dict``.

     Both files are tab-separated ``mid<TAB>count`` and are read in order
     (entities first, then types), so a mid present in both keeps the value
     from ``type_pop_fp``. Lines with fewer than two fields (e.g. blank
     lines) are skipped rather than raising IndexError.

     :param entity_pop_fp: path of the entity popularity file.
     :param type_pop_fp: path of the type popularity file.
     """
     for pop_fp in (entity_pop_fp, type_pop_fp):
         LogInfo.logs('Reading popularity from %s ...', pop_fp)
         with codecs.open(pop_fp, 'r', 'utf-8') as br:
             for line in br:
                 spt = line.strip().split('\t')
                 if len(spt) < 2:
                     continue
                 self.pop_dict[spt[0]] = int(spt[1])
     LogInfo.logs('%d <mid, popularity> loaded.', len(self.pop_dict))
示例#4
0
 def _load_pred(self, pred_name_fp):
     """Add every predicate id (first column) of ``pred_name_fp`` to ``self.pred_set``.

     The file is tab-separated UTF-8; lines with fewer than two fields are
     ignored.
     """
     with codecs.open(pred_name_fp, 'r', 'utf-8') as reader:
         for raw in reader:
             fields = raw.strip().split('\t')
             if len(fields) >= 2:
                 self.pred_set.add(fields[0])
     LogInfo.logs('%d predicates scanned.', len(self.pred_set))
示例#5
0
 def single_post_candgen(self, p_idx, post, link_fp, schema_fp):
     """Generate one-hop schema candidates for a Reddit post and save them.

     Linkings are read from ``link_fp`` when that file exists. Otherwise they
     are taken from ``self.p_links_dict[p_idx]``; each gets its list position
     as ``gl_pos``, and they are written to ``link_fp``. Every predicate that
     ``self.subj_pred_dict`` lists for a linked entity becomes a 1-hop
     ``Schema``. The schemas are written to ``schema_fp`` as JSON lines with
     keys ``ans_size``, ``hops`` and ``raw_paths``.

     Both files are written to a ``.tmp`` sibling and then moved into place.
     ``post`` is not used by this method.
     """
     # ---- Linking: reuse the cached file when available ----
     if os.path.isfile(link_fp):
         with codecs.open(link_fp, 'r', 'utf-8') as reader:
             gather_linkings = [
                 LinkData(**{k: v for k, v in json.loads(row.strip())})
                 for row in reader.readlines()
             ]
     else:
         gather_linkings = self.p_links_dict.get(p_idx, [])
         for pos, link in enumerate(gather_linkings):
             link.gl_pos = pos

     # ---- Persist linkings that were not read from disk ----
     if not os.path.isfile(link_fp):
         link_tmp = link_fp + '.tmp'
         with codecs.open(link_tmp, 'w', 'utf-8') as writer:
             for link in gather_linkings:
                 writer.write(json.dumps(link.serialize()) + '\n')
         shutil.move(link_tmp, link_fp)

     # ---- One schema per (linked entity, outgoing predicate) ----
     sc_list = []
     for gl_data in gather_linkings:
         for pred in self.subj_pred_dict.get(gl_data.value, set([])):
             sc = Schema()
             sc.hops = 1
             sc.main_pred_seq = [pred]
             sc.raw_paths = [('Main', gl_data, [pred])]
             sc.ans_size = 1
             sc_list.append(sc)
     if not sc_list:
         LogInfo.logs(
             "=============q_idx: %d sc_list=0======================" %
             p_idx)

     # ---- Save schemas; raw_paths items: (category, gl_pos, gl_mid, pred_seq) ----
     schema_tmp = schema_fp + '.tmp'
     with codecs.open(schema_tmp, 'w', 'utf-8') as writer:
         for sc in sc_list:
             sc_info_dict = {'ans_size': sc.ans_size, 'hops': sc.hops}
             sc_info_dict['raw_paths'] = [
                 (cate, gl.gl_pos, gl.value, pred_seq)
                 for cate, gl, pred_seq in sc.raw_paths
             ]
             writer.write(json.dumps(sc_info_dict) + '\n')
     shutil.move(schema_tmp, schema_fp)
示例#6
0
 def _load_type(self, type_name_fp):
     """Register Freebase types from ``type_name_fp`` as link targets.

     Each line is tab-separated ``type_mid<TAB>type_name`` (lines with fewer
     than two fields are skipped). Types are dropped when their domain (the
     part of the mid before the first '.') is in ``self.skip_domain_set``.
     For each kept type:

     * the lowercased name with '(s)' removed maps to the mid in
       ``self.surface_mid_dict``;
     * the name is stored in ``self.mid_name_dict``;
     * the mid is added to ``self.type_set``.
     """
     with codecs.open(type_name_fp, 'r', 'utf-8') as br:
         for line in br:
             spt = line.strip().split('\t')
             if len(spt) < 2:
                 continue
             type_mid, type_name = spt[0], spt[1]
             # partition() yields the whole mid when there is no '.', whereas
             # mid[:mid.find('.')] would become mid[:-1] and drop a character.
             type_prefix = type_mid.partition('.')[0]
             if type_prefix in self.skip_domain_set:
                 continue
             surface = type_name.lower().replace('(s)', '')
             self.surface_mid_dict.setdefault(surface, set()).add(type_mid)
             self.mid_name_dict[type_mid] = type_name
             self.type_set.add(type_mid)
     LogInfo.logs('After scanning %d types, %d <surface, mid_set> loaded.',
                  len(self.type_set), len(self.surface_mid_dict))
示例#7
0
def main(args):
    """Generate schema candidates for every Reddit post of split ``args.mode``.

    Reads the pickled dialogs ``<data_dir>/Reddit.<mode>.pkl`` and the entity
    links ``<data_dir>/Reddit.<mode>.links``. Per-post link and schema files
    go to ``<output_prefix>_<mode>/data/<lo>-<hi>/``, in buckets of 10000
    posts. Posts whose schema file already exists are skipped, so an
    interrupted run can be resumed.

    Finally ``all_list`` is written to the output directory: one relative
    schema path per post, newline-separated, with no trailing newline.
    """
    bucket_size = 10000  # posts per data/<lo>-<hi> sub-directory
    data_path = "%s/Reddit.%s.pkl" % (args.data_dir, args.mode)
    freebase_path = "%s/freebase-FB2M.txt" % args.freebase_dir
    links_path = "%s/Reddit.%s.links" % (args.data_dir, args.mode)

    # NOTE(review): pickle.load can execute arbitrary code; only load pickles
    # produced by this pipeline.
    with open(data_path, 'rb') as br:
        dg_list = pickle.load(br)
    LogInfo.logs('%d Reddit dialogs loaded.' % len(dg_list))

    cand_gen = RedditCandidateGenerator(freebase_fp=freebase_path,
                                        links_fp=links_path,
                                        verbose=args.verbose)

    output_dir = args.output_prefix + "_%s" % args.mode
    all_list_fp = output_dir + '/all_list'
    all_lists = []
    for p_idx, post in enumerate(dg_list):
        LogInfo.begin_track('Entering P %d / %d:', p_idx, len(dg_list))
        sub_idx = p_idx // bucket_size * bucket_size
        bucket = 'data/%d-%d' % (sub_idx, sub_idx + bucket_size - 1)
        all_lists.append('%s/%d_schema' % (bucket, p_idx))
        sub_dir = '%s/%s' % (output_dir, bucket)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        schema_fp = '%s/%d_schema' % (sub_dir, p_idx)
        link_fp = '%s/%d_links' % (sub_dir, p_idx)
        if os.path.isfile(schema_fp):
            # Already generated by an earlier (possibly interrupted) run.
            LogInfo.end_track('Skip this post, already saved.')
            continue

        cand_gen.single_post_candgen(p_idx=p_idx,
                                     post=post,
                                     link_fp=link_fp,
                                     schema_fp=schema_fp)
        LogInfo.end_track()

    # output_dir is otherwise only created inside the loop; without this an
    # empty dataset makes the open() below fail.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # One entry per line, no trailing newline.
    with open(all_list_fp, 'w') as fw:
        fw.write('\n'.join(all_lists))
示例#8
0
def main(args):
    """Generate schema candidates for every SimpleQuestions item.

    Reads the pickled questions ``<data_dir>/simpQ.data.pkl`` and the entity
    links ``<data_dir>/SimpQ.all.links``. Per-question link and schema files
    go to ``<output_dir>/data/<lo>-<hi>/``, in buckets of 1000 questions.
    Questions whose schema file already exists are skipped, so an
    interrupted run can be resumed.

    Finally ``all_list`` is written to the output directory: one relative
    schema path per question, newline-separated, with no trailing newline.
    """
    bucket_size = 1000  # questions per data/<lo>-<hi> sub-directory
    data_path = "%s/simpQ.data.pkl" % args.data_dir
    freebase_path = "%s/freebase-FB2M.txt" % args.freebase_dir
    links_path = "%s/SimpQ.all.links" % args.data_dir

    # NOTE(review): pickle.load can execute arbitrary code; only load pickles
    # produced by this pipeline.
    with open(data_path, 'rb') as br:
        qa_list = pickle.load(br)
    LogInfo.logs('%d SimpleQuestions loaded.' % len(qa_list))

    cand_gen = SimpleQCandidateGenerator(freebase_fp=freebase_path,
                                         links_fp=links_path,
                                         verbose=args.verbose)

    all_list_fp = args.output_dir + '/all_list'
    all_lists = []
    for q_idx, qa in enumerate(qa_list):
        LogInfo.begin_track('Entering Q %d / %d [%s]:', q_idx, len(qa_list),
                            qa['utterance'])
        sub_idx = q_idx // bucket_size * bucket_size
        bucket = 'data/%d-%d' % (sub_idx, sub_idx + bucket_size - 1)
        all_lists.append('%s/%d_schema' % (bucket, q_idx))
        sub_dir = '%s/%s' % (args.output_dir, bucket)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        schema_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(schema_fp):
            # Already generated by an earlier (possibly interrupted) run.
            LogInfo.end_track('Skip this question, already saved.')
            continue

        cand_gen.single_question_candgen(q_idx=q_idx,
                                         qa=qa,
                                         link_fp=link_fp,
                                         schema_fp=schema_fp)
        LogInfo.end_track()

    # args.output_dir is otherwise only created inside the loop; without this
    # an empty dataset makes the open() below fail.
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    # One entry per line, no trailing newline.
    with open(all_list_fp, 'w') as fw:
        fw.write('\n'.join(all_lists))
示例#9
0
def load_reddit(data_dir, mode='train'):
    """Load and parse one Reddit split, caching the result as a pickle.

    Reads ``<data_dir>/<mode>_v3.txt``, one JSON object per line with keys
    'post', 'response', 'corr_responses', 'all_triples' and 'all_entities'.
    Each post is whitespace-tokenized and dependency-parsed with CoreNLP.
    The dialog list is pickled to ``<data_dir>/Reddit.<mode>.pkl`` and also
    returned.

    :param data_dir: directory holding the raw split and the output pickle.
    :param mode: split name (default 'train').
    :return: list of dialog dicts.
    """
    import io  # local import: io.open accepts an explicit encoding on py2 and py3

    LogInfo.logs('Reddit initializing ... ')
    dg_list = []
    corenlp = StanfordCoreNLP(CORENLP_PATH)
    fp = '%s/%s_v3.txt' % (data_dir, mode)
    # Decode as UTF-8 explicitly instead of relying on the platform default
    # encoding. io.open keeps the same universal-newline handling as open().
    with io.open(fp, 'r', encoding='utf-8') as br:
        for line in br:
            if not line.strip():
                continue  # tolerate blank lines; json.loads would raise on them
            dg_line = json.loads(line)
            post = dg_line['post']
            dialog = {
                'utterance': post.strip(),
                'tokens': post.split(),
                'parse': corenlp.dependency_parse(post),
                'response': dg_line['response'].strip(),
                'corr_responses': dg_line['corr_responses'],
                'all_triples': dg_line['all_triples'],
                'all_entities': dg_line['all_entities']
            }

            dg_list.append(dialog)
            if len(dg_list) % 10000 == 0:
                LogInfo.logs('%d scanned.', len(dg_list))
    pickle_fp = '%s/Reddit.%s.pkl' % (data_dir, mode)
    with open(pickle_fp, 'wb') as bw:
        pickle.dump(dg_list, bw)
    LogInfo.logs('%d Reddit saved in [%s].' % (len(dg_list), pickle_fp))
    return dg_list
示例#10
0
def load_simpq(data_dir):
    """Load, tokenize and parse the SimpleQuestions train/valid/test files.

    Reads ``<data_dir>/annotated_fb_data_{train,valid,test}.txt``. Each line
    is ``subject<TAB>predicate<TAB>object<TAB>question``, and
    ``_remove_simpq_header`` is applied to subject, predicate and object.
    Each item is a dict with:

    * 'utterance': the question text;
    * 'targetValue': the (s, p, o) tuple;
    * 'tokens' / 'parse': the CoreNLP tokenization and dependency parse.

    All splits are concatenated, pickled to ``<data_dir>/simpQ.data.pkl`` and
    returned.
    """
    LogInfo.logs('SimpQ initializing ... ')
    qa_list = []
    corenlp = StanfordCoreNLP(CORENLP_PATH)
    for split in ('train', 'valid', 'test'):
        split_fp = '%s/annotated_fb_data_%s.txt' % (data_dir, split)
        with codecs.open(split_fp, 'r', 'utf-8') as reader:
            for row in reader.readlines():
                subj, pred, obj, question = row.strip().split('\t')
                triple = tuple(_remove_simpq_header(part)
                               for part in (subj, pred, obj))
                qa_list.append({
                    'utterance': question,
                    # Unlike other datasets, the target is a gold triple.
                    'targetValue': triple,
                    'tokens': corenlp.word_tokenize(question),
                    'parse': corenlp.dependency_parse(question),
                })
                if len(qa_list) % 1000 == 0:
                    LogInfo.logs('%d scanned.', len(qa_list))
    pickle_fp = '%s/simpQ.data.pkl' % data_dir
    with open(pickle_fp, 'wb') as writer:
        pickle.dump(qa_list, writer)
    LogInfo.logs('%d SimpleQuestions loaded.' % len(qa_list))
    return qa_list
示例#11
0
 def _load_fb_subset(self, freebase_fp):
     """Collect subject ids of a Freebase triple dump into ``self.subj_pred_keys``.

     Each line must have exactly three tab-separated fields (subject,
     predicate, object); only the subject is kept. The first
     ``len('www.freebase.com/')`` characters are dropped (assumed to be that
     URL prefix) and '/' is replaced by '.'.
     """
     LogInfo.begin_track('Loading freebase subset from [%s] ...',
                         freebase_fp)
     pref_len = len('www.freebase.com/')
     with codecs.open(freebase_fp, 'r', 'utf-8') as reader:
         all_lines = reader.readlines()
     total = len(all_lines)
     LogInfo.logs('%d lines loaded.', total)
     for pos, row in enumerate(all_lines):
         # Progress every 500000 lines, skipping the very first one.
         if pos and not pos % 500000:
             LogInfo.logs('Current: %d / %d', pos, total)
         subj, _pred, _obj = row.strip().split('\t')
         self.subj_pred_keys.add(subj[pref_len:].replace('/', '.'))
     LogInfo.logs('%d related entities loaded.', len(self.subj_pred_keys))
     LogInfo.end_track()
示例#12
0
 def _load_mid(self, mid_name_fp, allow_alias=False):
     """Build surface -> mid and mid -> name dictionaries from ``mid_name_fp``.

     Each line is tab-separated with at least three fields
     ``mid, relation, name``; shorter lines are skipped. A mid is ignored
     when it has no '.' or when its prefix before the first '.' is in
     ``self.skip_domain_set``. For the remaining lines:

     * rows whose relation is ``type.object.name`` store the name in
       ``self.mid_name_dict``;
     * the lowercased name is mapped to the mid in ``self.surface_mid_dict``
       for those rows, and for every other relation too when
       ``allow_alias`` is True.

     The progress counter only counts lines with at least three fields.
     """
     LogInfo.begin_track('Loading surface --> mid dictionary from [%s] ...',
                         mid_name_fp)
     scan = 0
     with codecs.open(mid_name_fp, 'r', 'utf-8') as reader:
         for row in reader:
             fields = row.strip().split('\t')
             if len(fields) < 3:
                 continue
             mid, relation, name = fields[0], fields[1], fields[2]
             dot = mid.find('.')
             keep = dot != -1 and mid[:dot] not in self.skip_domain_set
             if keep:
                 is_name_row = relation == 'type.object.name'
                 if is_name_row:
                     self.mid_name_dict[mid] = name
                 if is_name_row or allow_alias:
                     # Lowercase surface is the search entrance.
                     self.surface_mid_dict.setdefault(name.lower(),
                                                      set()).add(mid)
             scan += 1
             if scan % 100000 == 0:
                 LogInfo.logs('%d lines scanned.', scan)
     LogInfo.logs('%d lines scanned.', scan)
     LogInfo.logs('%d <surface, mid_set> loaded.',
                  len(self.surface_mid_dict))
     LogInfo.logs('%d <mid, name> loaded.', len(self.mid_name_dict))
     LogInfo.end_track()
示例#13
0
    def _load_fb_subset(self, fb_fp):
        """Build ``self.subj_pred_dict``: normalized subject -> set of predicates.

        Each line must have exactly three tab-separated fields (subject,
        predicate, object). Subject and predicate are normalized by dropping
        the first ``len('www.freebase.com/')`` characters (assumed to be that
        URL prefix) and replacing '/' with '.'. Objects are ignored.
        """
        LogInfo.begin_track('Loading freebase subset from [%s] ...', fb_fp)
        pref_len = len('www.freebase.com/')

        def _normalize(uri):
            # Drop the URL prefix (by length) and turn path separators into dots.
            return uri[pref_len:].replace('/', '.')

        with codecs.open(fb_fp, 'r', 'utf-8') as reader:
            all_lines = reader.readlines()
        total = len(all_lines)
        LogInfo.logs('%d lines loaded.', total)
        for pos, row in enumerate(all_lines):
            if not pos % 500000:
                LogInfo.logs('Current: %d / %d', pos, total)
            subj, pred, _obj = row.strip().split('\t')
            subj_key = _normalize(subj)
            self.subj_pred_dict.setdefault(subj_key, set()).add(_normalize(pred))
        n_entities = len(self.subj_pred_dict)
        n_pairs = sum(len(preds) for preds in self.subj_pred_dict.values())
        LogInfo.logs('%d related entities and %d <S, P> pairs saved.',
                     n_entities, n_pairs)
        LogInfo.end_track()
示例#14
0
def main(args):
    """Train and evaluate the KBQA relation-matching model.

    Setup:
    * Loads word embeddings and Freebase metadata.
    * Builds the schema dataset plus its optimisation and evaluation loaders.
    * Builds ``KbqaModel`` and dumps its config to ``<output_dir>/config.json``.
    * Either resumes from ``<output_dir>/<resume_model_name>`` (when that
      directory exists) or initialises all variables with fresh word, path
      and mid embeddings.

    Each epoch then:
    * calls ``optimizer.optimize_all`` and evaluates F1 on train/valid/test;
    * saves the best model by valid F1 to ``<output_dir>/model_best``;
    * appends a row to ``<output_dir>/status.txt``;
    * dumps the selected parameters to ``<output_dir>/detail/param.<epoch>.pkl``.

    Training stops early after ``args.max_patience`` consecutive epochs
    without a valid-F1 improvement.
    """

    # ==== Loading Necessary Utils ====
    LogInfo.begin_track('Loading Utils ... ')
    wd_emb_util = WordEmbeddingUtil(emb_dir=args.emb_dir, dim_emb=args.dim_emb)
    freebase_helper = FreebaseHelper(meta_dir=args.fb_meta_dir)
    LogInfo.end_track()

    # ==== Loading Dataset ====
    LogInfo.begin_track('Creating Dataset ... ')
    schema_dataset = SchemaDataset(data_dir=args.data_dir,
                                   candgen_dir=args.candgen_dir,
                                   schema_level=args.schema_level,
                                   freebase_helper=freebase_helper)
    schema_dataset.load_all_data()
    active_dicts = schema_dataset.active_dicts
    qa_list = schema_dataset.qa_list
    feature_helper = FeatureHelper(active_dicts,
                                   qa_list,
                                   freebase_helper,
                                   path_max_size=args.path_max_size,
                                   qw_max_len=args.qw_max_len,
                                   pw_max_len=args.pw_max_len,
                                   pseq_max_len=args.pseq_max_len)
    ds_builder = SchemaBuilder(schema_dataset=schema_dataset,
                               feature_helper=feature_helper,
                               neg_f1_ths=args.neg_f1_ths,
                               neg_max_sample=args.neg_max_sample,
                               neg_strategy=args.neg_strategy)
    LogInfo.end_track()

    # ==== Building Model ====
    LogInfo.begin_track('Building Model and Session ... ')
    model_config = {
        'qw_max_len': args.qw_max_len,
        'pw_max_len': args.pw_max_len,
        'path_max_size': args.path_max_size,
        'pseq_max_len': args.pseq_max_len,
        'dim_emb': args.dim_emb,
        'w_emb_fix': args.w_emb_fix,
        'n_words': args.n_words,
        'n_mids': args.n_mids,
        'n_paths': args.n_paths,
        'drop_rate': args.drop_rate,
        'rnn_config': {
            'cell_class': args.cell_class,
            'num_units': args.num_units,
            'num_layers': args.num_layers
        },
        'att_config': {
            'att_func': args.att_func,
            'dim_att_hidden': args.dim_att_hidden
        },
        'path_usage': args.path_usage,
        'sent_usage': args.sent_usage,
        'seq_merge_mode': args.seq_merge_mode,
        'scoring_mode': args.scoring_mode,
        'final_func': args.final_func,
        'loss_margin': args.loss_margin,
        'optm_name': args.optm_name,
        'learning_rate': args.lr_rate
    }
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    with open("%s/config.json" % args.output_dir, 'w') as fw:
        json.dump(model_config, fw)

    kbqa_model = KbqaModel(**model_config)

    LogInfo.logs('Showing final parameters: ')
    for var in tf.global_variables():
        LogInfo.logs('%s: %s', var.name, var.get_shape().as_list())

    LogInfo.end_track()

    # ==== Focused on specific params ====
    if args.final_func == 'bilinear':
        focus_param_name_list = ['rm_task/rm_forward/bilinear_mat']
    else:  # mlp
        focus_param_name_list = [
            'rm_task/rm_forward/fc1/weights', 'rm_task/rm_forward/fc1/biases',
            'rm_task/rm_forward/fc2/weights', 'rm_task/rm_forward/fc2/biases'
        ]
    focus_param_list = []
    found_param_names = []
    with tf.variable_scope('', reuse=tf.AUTO_REUSE):
        for param_name in focus_param_name_list:
            try:
                var = tf.get_variable(name=param_name)
            except ValueError:
                LogInfo.logs("ValueError occured for %s!" % param_name)
                continue
            focus_param_list.append(var)
            found_param_names.append(param_name)
    # Keep names aligned with the tensors actually found. Otherwise a missing
    # variable would shift every later name onto the wrong tensor in the
    # zips below.
    focus_param_name_list = found_param_names
    LogInfo.begin_track('Showing %d concern parameters: ',
                        len(focus_param_list))
    for name, tensor in zip(focus_param_name_list, focus_param_list):
        LogInfo.logs('%s --> %s', name, tensor.get_shape().as_list())
    LogInfo.end_track()

    # ==== Initializing model ====
    saver = tf.train.Saver()
    gpu_options = tf.GPUOptions(
        allow_growth=True, per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            intra_op_parallelism_threads=8))
    LogInfo.begin_track('Running global_variables_initializer ...')
    start_epoch = 0
    best_valid_f1 = 0.
    resume_flag = False
    model_dir = None
    if args.resume_model_name not in ('', 'None'):
        model_dir = '%s/%s' % (args.output_dir, args.resume_model_name)
        if os.path.exists(model_dir):
            resume_flag = True
    if resume_flag:
        start_epoch, best_valid_f1 = model_util.load_model(saver=saver,
                                                           sess=sess,
                                                           model_dir=model_dir)
    else:
        dep_simulate = args.dep_simulate == 'True'
        wd_emb_mat = wd_emb_util.produce_active_word_embedding(
            active_word_dict=schema_dataset.active_dicts['word'],
            dep_simulate=dep_simulate)
        pa_emb_mat = np.random.uniform(
            low=-0.1,
            high=0.1,
            size=(model_config['n_paths'],
                  model_config['dim_emb'])).astype('float32')
        mid_emb_mat = np.random.uniform(
            low=-0.1,
            high=0.1,
            size=(model_config['n_mids'],
                  model_config['dim_emb'])).astype('float32')
        LogInfo.logs('%s random path embedding created.', pa_emb_mat.shape)
        LogInfo.logs('%s random mid embedding created.', mid_emb_mat.shape)
        sess.run(tf.global_variables_initializer(),
                 feed_dict={
                     kbqa_model.w_embedding_init: wd_emb_mat,
                     kbqa_model.p_embedding_init: pa_emb_mat,
                     kbqa_model.m_embedding_init: mid_emb_mat
                 })
    LogInfo.end_track('Model build complete.')

    # ==== Running optm / eval ====
    optimizer = Optimizer(model=kbqa_model, sess=sess)
    evaluator = Evaluator(model=kbqa_model, sess=sess)
    optm_data_loader = ds_builder.build_optm_dataloader(
        optm_batch_size=args.optm_batch_size)
    eval_data_list = ds_builder.build_eval_dataloader(
        eval_batch_size=args.eval_batch_size)

    if not os.path.exists('%s/detail' % args.output_dir):
        os.mkdir('%s/detail' % args.output_dir)
    if not os.path.exists('%s/result' % args.output_dir):
        os.mkdir('%s/result' % args.output_dir)

    LogInfo.begin_track('Learning start ...')

    patience = args.max_patience
    for epoch in range(start_epoch + 1, args.max_epoch + 1):
        if patience == 0:
            LogInfo.logs('Early stopping at epoch = %d.', epoch)
            break
        update_flag = False
        disp_item_dict = {'Epoch': epoch}

        LogInfo.begin_track('Epoch %d / %d', epoch, args.max_epoch)

        LogInfo.begin_track('Optimizing ... ')
        optimizer.optimize_all(optm_data_loader=optm_data_loader)

        LogInfo.logs('loss = %.6f', optimizer.ret_loss)
        disp_item_dict['rm_loss'] = optimizer.ret_loss
        LogInfo.end_track()

        LogInfo.begin_track('Evaluation:')
        for mark, eval_dl in zip(['train', 'valid', 'test'], eval_data_list):
            LogInfo.begin_track('Eval-%s ...', mark)
            disp_key = '%s_F1' % mark
            detail_fp = '%s/detail/%s.tmp' % (args.output_dir, mark)
            result_fp = '%s/result/%s.%03d.result' % (args.output_dir, mark,
                                                      epoch)
            disp_item_dict[disp_key] = evaluator.evaluate_all(
                eval_data_loader=eval_dl,
                detail_fp=detail_fp,
                result_fp=result_fp)
            LogInfo.end_track()
        LogInfo.end_track()

        # Display & save states (results, details, params)
        cur_valid_f1 = disp_item_dict['valid_F1']
        if cur_valid_f1 > best_valid_f1:
            best_valid_f1 = cur_valid_f1
            update_flag = True
            patience = args.max_patience
            save_best_dir = '%s/model_best' % args.output_dir
            model_util.delete_dir(save_best_dir)
            model_util.save_model(saver=saver,
                                  sess=sess,
                                  model_dir=save_best_dir,
                                  epoch=epoch,
                                  valid_metric=best_valid_f1)
        else:
            patience -= 1
        # Log both values: when the model "stayed", the current F1 is below
        # the best one and must not be reported as the best.
        LogInfo.logs('Model %s, valid_F1 = %.6f, best valid_F1 = %.6f '
                     '[patience = %d]',
                     'updated' if update_flag else 'stayed', cur_valid_f1,
                     best_valid_f1, patience)
        disp_item_dict['Status'] = 'UPDATE' if update_flag else str(patience)
        disp_item_dict['Time'] = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        status_fp = '%s/status.txt' % args.output_dir
        disp_header_list = model_util.construct_display_header()
        if epoch == 1:
            with open(status_fp, 'w') as bw:
                write_str = ''.join(disp_header_list)
                bw.write(write_str + '\n')
        with open(status_fp, 'a') as bw:
            write_str = ''
            for header in disp_header_list:
                # Headers ending in a space (or a bare tab) are separators,
                # copied verbatim; other headers name a value to display.
                if header.endswith(' ') or header == '\t':
                    write_str += header
                else:
                    val = disp_item_dict.get(header, '--------')
                    if isinstance(val, float):
                        write_str += '%8.6f' % val
                    else:
                        write_str += str(val)
            bw.write(write_str + '\n')

        LogInfo.logs('Output concern parameters ... ')
        # don't need any feeds, since we focus on parameters
        param_result_list = sess.run(focus_param_list)
        param_result_dict = dict(zip(focus_param_name_list,
                                     param_result_list))

        with open(args.output_dir + '/detail/param.%03d.pkl' % epoch,
                  'wb') as bw:
            pickle.dump(param_result_dict, bw)
        LogInfo.logs('Concern parameters saved.')

        if update_flag:
            with open(args.output_dir + '/detail/param.best.pkl', 'wb') as bw:
                pickle.dump(param_result_dict, bw)
            # save the latest details
            for mode in ['train', 'valid', 'test']:
                src = '%s/detail/%s.tmp' % (args.output_dir, mode)
                dest = '%s/detail/%s.best' % (args.output_dir, mode)
                if os.path.isfile(src):
                    shutil.move(src, dest)

        LogInfo.end_track()  # end of epoch
    LogInfo.end_track()  # end of learning
示例#15
0
    def single_question_candgen(self, q_idx, qa, link_fp, schema_fp):
        """Generate and save 1-hop schema candidates for one SimpleQuestions item.

        Linkings are read from ``link_fp`` when that file exists. Otherwise
        they are taken from ``self.q_links_dict[q_idx]``; each gets its list
        position as ``gl_pos``, and they are written to ``link_fp``.

        Every predicate that ``self.subj_pred_dict`` lists for a linked
        entity becomes a 1-hop ``Schema``. It gets p = r = f1 = 1.0 when
        (entity, predicate) equals the gold subject and predicate in
        ``qa['targetValue']``, and 0.0 otherwise.

        Schemas are written to ``schema_fp`` as JSON lines with keys p, r,
        f1, ans_size, hops, agg and raw_paths. Both files are written to a
        ``.tmp`` sibling and then moved into place.
        """
        # ---- Linking: reuse the cached file when available ----
        if os.path.isfile(link_fp):
            with codecs.open(link_fp, 'r', 'utf-8') as reader:
                gather_linkings = [
                    LinkData(**{k: v for k, v in json.loads(row.strip())})
                    for row in reader.readlines()
                ]
            LogInfo.logs('Read %d links from file.', len(gather_linkings))
        else:
            gather_linkings = self.q_links_dict.get(q_idx, [])
            for pos, link in enumerate(gather_linkings):
                link.gl_pos = pos

        LogInfo.begin_track('Show %d E links :', len(gather_linkings))
        if self.verbose >= 1:
            for link in gather_linkings:
                LogInfo.logs(link.display())
        LogInfo.end_track()

        # ---- Persist linkings that were not read from disk ----
        if not os.path.isfile(link_fp):
            link_tmp = link_fp + '.tmp'
            with codecs.open(link_tmp, 'w', 'utf-8') as writer:
                for link in gather_linkings:
                    writer.write(json.dumps(link.serialize()) + '\n')
            shutil.move(link_tmp, link_fp)
            LogInfo.logs('%d link data save to file.', len(gather_linkings))

        # ---- One schema per (linked entity, predicate), scored vs. gold ----
        gold_entity, gold_pred, _ = qa['targetValue']
        sc_list = []
        for gl_data in gather_linkings:
            entity = gl_data.value
            for pred in self.subj_pred_dict.get(entity, set([])):
                sc = Schema()
                sc.hops = 1
                sc.aggregate = False
                sc.main_pred_seq = [pred]
                sc.raw_paths = [('Main', gl_data, [pred])]
                sc.ans_size = 1
                is_gold = entity == gold_entity and pred == gold_pred
                score = 1. if is_gold else 0.
                sc.f1 = sc.p = sc.r = score
                sc_list.append(sc)

        # ---- Save schemas; raw_paths items: (category, gl_pos, gl_mid, pred_seq) ----
        schema_tmp = schema_fp + '.tmp'
        with codecs.open(schema_tmp, 'w', 'utf-8') as writer:
            for sc in sc_list:
                sc_info_dict = {
                    key: getattr(sc, key)
                    for key in ('p', 'r', 'f1', 'ans_size', 'hops')
                }
                if sc.aggregate is not None:
                    sc_info_dict['agg'] = sc.aggregate
                sc_info_dict['raw_paths'] = [
                    (cate, gl.gl_pos, gl.value, pred_seq)
                    for cate, gl, pred_seq in sc.raw_paths
                ]
                writer.write(json.dumps(sc_info_dict) + '\n')
        shutil.move(schema_tmp, schema_fp)
        LogInfo.logs('%d schemas successfully saved into [%s].', len(sc_list),
                     schema_fp)