Example #1
    def post_process(self, eval_dl, detail_fp, result_fp):
        if len(self.eval_detail_dict) == 0:
            self.eval_detail_dict = {k: [] for k in self.concern_name_list}

        ret_q_score_dict = {}
        for scan_idx, (q_idx, cand) in enumerate(eval_dl.eval_sc_tup_list):
            cand.run_info = {
                k: self.eval_detail_dict[k][scan_idx]
                for k in self.concern_name_list
            }
            ret_q_score_dict.setdefault(q_idx, []).append(cand)
            # put all output results into sc.run_info

        f1_list = []
        for q_idx, score_list in ret_q_score_dict.items():
            score_list.sort(key=lambda x: x.run_info['full_score'],
                            reverse=True)  # sort by score DESC
            if len(score_list) == 0:
                f1_list.append(0.)
            else:
                f1_list.append(score_list[0].f1)
        LogInfo.logs('[%3s] Predict %d out of %d questions.', self.name,
                     len(f1_list), eval_dl.total_questions)
        ret_metric = np.sum(f1_list).astype(
            'float32') / eval_dl.total_questions

        if detail_fp is not None:
            schema_dataset = eval_dl.schema_dataset
            bw = open(detail_fp, 'w')
            LogInfo.redirect(bw)
            np.set_printoptions(threshold=np.nan)
            LogInfo.logs('Avg_f1 = %.6f', ret_metric)
            srt_q_idx_list = sorted(ret_q_score_dict.keys())
            for q_idx in srt_q_idx_list:
                qa = schema_dataset.qa_list[q_idx]
                q = qa['utterance']
                LogInfo.begin_track('Q-%04d [%s]:', q_idx, q.encode('utf-8'))
                srt_list = ret_q_score_dict[q_idx]  # already sorted
                best_label_f1 = np.max([sc.f1 for sc in srt_list])
                best_label_f1 = max(best_label_f1, 0.000001)
                for rank, sc in enumerate(srt_list):
                    if rank < 20 or sc.f1 == best_label_f1:
                        LogInfo.begin_track(
                            '#-%04d [F1 = %.6f] [row_in_file = %d]', rank + 1,
                            sc.f1, sc.ori_idx)
                        LogInfo.logs('full_score: %.6f',
                                     sc.run_info['full_score'])
                        show_overall_detail(sc)
                        LogInfo.end_track()
                LogInfo.end_track()
            LogInfo.logs('Avg_f1 = %.6f', ret_metric)
            np.set_printoptions()  # reset output format
            LogInfo.stop_redirect()
            bw.close()

        self.ret_q_score_dict = ret_q_score_dict
        return [ret_metric]
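
The post_process above (and most of the examples below) turns ranked candidates into one number in the same way: group candidates by question index, sort each group by model score in descending order, take the F1 of the top-ranked candidate, and average over the total number of questions. A self-contained sketch of just that pattern, with a hypothetical Candidate class standing in for the project's schema objects:

import numpy as np


class Candidate(object):
    """Hypothetical stand-in for a candidate schema: only f1 and run_info matter here."""
    def __init__(self, f1, score):
        self.f1 = f1
        self.run_info = {'full_score': score}


def average_top1_f1(ret_q_score_dict, total_questions):
    """Sort each question's candidates by score DESC, then average the top-1 F1."""
    f1_list = []
    for score_list in ret_q_score_dict.values():
        score_list.sort(key=lambda x: x.run_info['full_score'], reverse=True)
        f1_list.append(score_list[0].f1 if score_list else 0.)
    return np.sum(f1_list).astype('float32') / total_questions


# Toy usage: two questions produced candidates, a third produced none,
# and for q_idx 0 the top-scored candidate is not the best-F1 one.
cands = {0: [Candidate(f1=0.5, score=1.2), Candidate(f1=1.0, score=0.3)],
         1: [Candidate(f1=0.8, score=2.0)]}
print(average_top1_f1(cands, total_questions=3))  # (0.5 + 0.8) / 3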
Example #2
def post_process(q_cands_tup_list, predict_score_list, eval_dl, detail_fp):
    scan_idx = 0
    for q_idx, cand_list in q_cands_tup_list:
        for cand in cand_list:
            cand.run_info['ltr_score'] = predict_score_list[scan_idx]
            scan_idx += 1
    assert len(predict_score_list) == scan_idx

    f1_list = []
    for q_idx, cand_list in q_cands_tup_list:
        cand_list.sort(key=lambda x: x.run_info['ltr_score'],
                       reverse=True)  # sort by score DESC
        if len(cand_list) == 0:
            f1_list.append(0.)
        else:
            f1_list.append(cand_list[0].f1)
    LogInfo.logs('[ltr-%s] Predict %d out of %d questions.', eval_dl.mode,
                 len(f1_list), eval_dl.total_questions)
    ret_metric = np.sum(f1_list).astype('float32') / eval_dl.total_questions

    if detail_fp is not None:
        schema_dataset = eval_dl.schema_dataset
        bw = open(detail_fp, 'w')
        LogInfo.redirect(bw)
        np.set_printoptions(threshold=np.nan)
        LogInfo.logs('Avg_f1 = %.6f', ret_metric)

        for q_idx, cand_list in q_cands_tup_list:
            qa = schema_dataset.qa_list[q_idx]
            q = qa['utterance']
            LogInfo.begin_track('Q-%04d [%s]:', q_idx, q.encode('utf-8'))
            best_label_f1 = np.max([sc.f1 for sc in cand_list])
            best_label_f1 = max(best_label_f1, 0.000001)
            for rank, sc in enumerate(cand_list):
                if rank < 20 or sc.f1 == best_label_f1:
                    LogInfo.begin_track(
                        '#-%04d [F1 = %.6f] [row_in_file = %d]', rank + 1,
                        sc.f1, sc.ori_idx)
                    LogInfo.logs('ltr_score: %.6f', sc.run_info['ltr_score'])
                    show_overall_detail(sc)
                    LogInfo.end_track()
            LogInfo.end_track()
        LogInfo.logs('Avg_f1 = %.6f', ret_metric)
        np.set_printoptions()  # reset output format
        LogInfo.stop_redirect()
        bw.close()

    LogInfo.logs('[ltr] %s_F1 = %.6f', eval_dl.mode, ret_metric)
    return ret_metric
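
Every example on this page routes its detail output through the same logging helper: LogInfo.redirect(bw) sends subsequent logs / begin_track / end_track calls into an open file, and LogInfo.stop_redirect() switches back to normal logging before the caller closes the file. The project's actual implementation is not shown here; the following is only a guessed minimal stand-in that mirrors the call surface visible in these examples, which can be handy for running such snippets outside the original repository:

import sys


class LogInfo(object):
    """Hypothetical minimal stand-in; the real helper class may differ in its details."""
    _out = sys.stdout
    _depth = 0

    @classmethod
    def redirect(cls, fp):
        # send all following log output into the given open file object
        cls._out = fp

    @classmethod
    def stop_redirect(cls):
        # restore logging to stdout (the caller still closes the file itself)
        cls._out = sys.stdout

    @classmethod
    def logs(cls, fmt='', *args):
        # printf-style line, indented by the current tracking depth
        msg = (fmt % args) if args else str(fmt)
        cls._out.write('  ' * cls._depth + msg + '\n')

    @classmethod
    def begin_track(cls, fmt='', *args):
        # log a header line and increase indentation for nested output
        cls.logs(fmt, *args)
        cls._depth += 1

    @classmethod
    def end_track(cls, fmt='', *args):
        # close the current block; optionally log a trailing line
        cls._depth = max(0, cls._depth - 1)
        if fmt:
            cls.logs(fmt, *args)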
Example #3
def main(qa_list, detail_fp, out_eval_fp, out_anno_fp, bao_predict_fp, data_dir):

    bao_info_dict = read_bao(bao_predict_fp, qa_list)
    our_info_dict = read_ours(detail_fp)
    for q_idx in range(1300, 2100):
        if q_idx not in our_info_dict:
            our_info_dict[q_idx] = {'F1': 0., 'sc_line': -1}
    disp_list = sorted(range(1300, 2100), key=lambda idx: our_info_dict[idx]['F1'])
    with open(out_eval_fp, 'w') as bw_eval, open(out_anno_fp, 'w') as bw_anno:
        LogInfo.redirect(bw_eval)
        LogInfo.begin_track('Showing comparison: ')
        LogInfo.logs('Our: %s', detail_fp)
        LogInfo.logs('Bao: %s', bao_predict_fp)
        for q_idx in disp_list:
            LogInfo.logs('')
            LogInfo.logs('==================================================================')
            LogInfo.logs('')
            LogInfo.begin_track('Q-%d: [%s]', q_idx, qa_list[q_idx]['utterance'].encode('utf-8'))
            our_f1 = our_info_dict[q_idx]['F1']
            bao_f1 = bao_info_dict[q_idx]['F1']
            LogInfo.logs('F1_gain = %.6f, Our_F1 = %.6f, Bao_F1 = %.6f',
                         our_f1 - bao_f1, our_f1, bao_f1)
            LogInfo.logs('Gold: %s', qa_list[q_idx]['targetValue'])
            LogInfo.logs('Bao: %s', bao_info_dict[q_idx]['predict_list'])
            LogInfo.begin_track('Our sc_line: %d', our_info_dict[q_idx]['sc_line'])
            retrieve_schema(data_dir=data_dir, q_idx=q_idx, line_no=our_info_dict[q_idx]['sc_line'])
            LogInfo.end_track()     # end of schema display
            LogInfo.end_track()     # end of the question

            bw_anno.write('Q-%04d\t%s\t%s\n' % (q_idx, qa_list[q_idx]['utterance'].encode('utf-8'),
                                                qa_list[q_idx]['targetValue']))
            bw_anno.write('Our_F1: %.6f\n' % our_f1)
            bw_anno.write('1: []\n')
            bw_anno.write('2: []\n')
            bw_anno.write('3: []\n')
            bw_anno.write('4: []\n')
            bw_anno.write('5: []\n')
            bw_anno.write('\n\n')

        LogInfo.end_track()
        LogInfo.stop_redirect()
Example #4
def main(args):
    # ==== Optm & Eval register ==== #
    # ltr: learning-to-rank; full: fully-connected layer as the last layer
    full_optm_method = args.full_optm_method
    if full_optm_method in ('el', 'rm'):  # sub-task only mode
        optm_tasks = eval_tasks = [full_optm_method]
    else:  # ltr or full
        optm_tasks = ['el', 'rm', 'full']
        eval_tasks = ['el', 'rm', 'full']
        # all sub-tasks needed, including full
        # we need direct comparison between full and ltr
        if full_optm_method == 'ltr':
            eval_tasks.append('ltr')
    LogInfo.logs('full_optm_method: %s', full_optm_method)
    LogInfo.logs('optimize tasks: %s', optm_tasks)
    LogInfo.logs('evaluate tasks: %s', eval_tasks)

    # ==== Loading Necessary Util ==== #
    LogInfo.begin_track('Loading Utils ... ')
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    wd_emb_util.load_word_indices()
    wd_emb_util.load_mid_indices()
    LogInfo.end_track()

    # ==== Loading Dataset ==== #
    LogInfo.begin_track('Creating Dataset ... ')
    data_config = literal_eval(args.data_config)
    data_config['wd_emb_util'] = wd_emb_util
    data_config['verbose'] = args.verbose
    schema_dataset = SchemaDatasetACL18(**data_config)

    LogInfo.end_track()

    # ==== Building Model ==== #
    LogInfo.begin_track('Building Model and Session ... ')
    gpu_options = tf.GPUOptions(
        allow_growth=True, per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            intra_op_parallelism_threads=8))
    model_config = literal_eval(args.model_config)
    el_use_type = model_config['el_kernel_conf']['use_type']
    for key in ('q_max_len', 'sc_max_len', 'path_max_len', 'pword_max_len',
                'type_dist_len', 'use_ans_type_dist'):
        model_config[key] = getattr(schema_dataset, key)
    for key in ('n_words', 'n_mids', 'dim_emb'):
        model_config[key] = getattr(wd_emb_util, key)
    model_config['el_feat_size'] = 5  # TODO: manually assigned
    model_config['extra_feat_size'] = 16
    if full_optm_method != 'full':
        model_config[
            'full_back_prop'] = False  # make sure rm/el optimizes during all epochs
    full_back_prop = model_config['full_back_prop']
    compq_mt_model = CompqMultiTaskModel(**model_config)

    LogInfo.begin_track('Showing final parameters: ')
    for var in tf.global_variables():
        LogInfo.logs('%s: %s', var.name, var.get_shape().as_list())
    LogInfo.end_track()

    focus_param_name_list = [
        'el_kernel/out_fc/weights', 'el_kernel/out_fc/biases',
        'full_task/final_fc/weights', 'full_task/final_fc/biases',
        'abcnn1_rm_kernel/sim_ths', 'abcnn2_rm_kernel/sim_ths'
    ]
    focus_param_list = []
    with tf.variable_scope('', reuse=tf.AUTO_REUSE):
        for param_name in focus_param_name_list:
            try:
                var = tf.get_variable(name=param_name)
                focus_param_list.append(var)
            except ValueError:
                pass
    LogInfo.begin_track('Showing %d concern parameters: ',
                        len(focus_param_list))
    for name, tensor in zip(focus_param_name_list, focus_param_list):
        LogInfo.logs('%s --> %s', name, tensor.get_shape().as_list())
    LogInfo.end_track()

    saver = tf.train.Saver()
    LogInfo.begin_track('Running global_variables_initializer ...')
    start_epoch = 0
    best_valid_f1 = 0.
    resume_flag = False
    model_dir = None
    if args.resume_model_name not in ('', 'None'):
        model_dir = '%s/%s' % (args.output_dir, args.resume_model_name)
        if os.path.exists(model_dir):
            resume_flag = True
    if resume_flag:
        start_epoch, best_valid_f1 = load_model(saver=saver,
                                                sess=sess,
                                                model_dir=model_dir)
    else:
        wd_emb_util.load_word_embeddings()
        wd_emb_util.load_mid_embeddings()
        sess.run(tf.global_variables_initializer(),
                 feed_dict={
                     compq_mt_model.w_embedding_init:
                     wd_emb_util.word_emb_matrix,
                     compq_mt_model.m_embedding_init:
                     wd_emb_util.mid_emb_matrix
                 })
    LogInfo.end_track('Start Epoch = %d', start_epoch)
    LogInfo.end_track('Model build complete.')

    # ==== Register optm / eval ==== #
    el_optimizer = EntityLinkingOptimizer(compq_mt_model=compq_mt_model,
                                          sess=sess,
                                          ob_batch_num=100)
    el_evaluator = EntityLinkingEvaluator(compq_mt_model=compq_mt_model,
                                          sess=sess,
                                          ob_batch_num=100)
    rm_optimizer = RelationMatchingOptimizer(compq_mt_model=compq_mt_model,
                                             sess=sess,
                                             ob_batch_num=100)
    rm_evaluator = RelationMatchingEvaluator(compq_mt_model=compq_mt_model,
                                             sess=sess,
                                             ob_batch_num=100)
    full_optimizer = FullTaskOptimizer(compq_mt_model=compq_mt_model,
                                       sess=sess,
                                       ob_batch_num=100)
    full_evaluator = FullTaskEvaluator(compq_mt_model=compq_mt_model,
                                       sess=sess,
                                       ob_batch_num=100)
    LogInfo.logs('Optimizer & Evaluator defined for RM, EL and FULL.')

    schema_dataset.load_smart_cands()

    # ==== Iteration begins ==== #
    output_dir = args.output_dir
    if not os.path.exists(output_dir + '/detail'):
        os.makedirs(output_dir + '/detail')
    if not os.path.exists(output_dir + '/result'):
        os.makedirs(output_dir + '/result')

    LogInfo.begin_track('Learning start ...')
    patience = args.max_patience

    status_fp = output_dir + '/status.csv'
    raw_header_list = ['Epoch']
    for task_name in ('ltr', 'full', 'el',
                      'rm'):  # all possibilities of Optm/T/v/t
        local_header_list = []
        if task_name in optm_tasks:
            local_header_list.append('%s_loss' % task_name)
        if task_name in eval_tasks:
            for mark in 'Tvt':
                local_header_list.append('%s_%s_F1' % (task_name, mark))
        if len(local_header_list) > 0:
            raw_header_list.append(' |  ')
        raw_header_list += local_header_list
    raw_header_list += [' |  ', 'Status', 'Time']
    disp_header_list = []
    no_tab = True
    for header in raw_header_list:  # dynamically insert \t between raw headers
        if not (no_tab or header.endswith(' ')):
            disp_header_list.append('\t')
        disp_header_list.append(header)
        no_tab = header.endswith(' ')
    with open(status_fp, 'a') as bw:
        write_str = ''.join(disp_header_list)
        bw.write(write_str + '\n')

    if full_back_prop:
        LogInfo.logs('full_back_prop = %s, pre_train_steps = %d.',
                     full_back_prop, args.pre_train_steps)
    else:
        LogInfo.logs('no pre-train available.')
    for epoch in range(start_epoch + 1, args.max_epoch + 1):
        if patience == 0:
            LogInfo.logs('Early stopping at epoch = %d.', epoch)
            break
        update_flag = False
        disp_item_dict = {'Epoch': epoch}

        LogInfo.begin_track('Epoch %d / %d', epoch, args.max_epoch)

        LogInfo.begin_track('Generating dynamic schemas ...')
        # TODO: a large block of code for generating schemas on-the-fly
        task_dls_dict = {}
        for task_name in eval_tasks:
            task_dls_dict[task_name] = build_task_dataloaders(
                task_name=task_name,
                schema_dataset=schema_dataset,
                compq_mt_model=compq_mt_model,
                optm_batch_size=args.optm_batch_size,
                eval_batch_size=args.eval_batch_size,
                el_use_type=el_use_type,
                neg_f1_ths=args.neg_f1_ths,
                neg_max_sample=args.neg_max_sample
            )  # [task_optm_dl, task_eval_train_dl, ...]
        el_dl_list = task_dls_dict.get('el')
        rm_dl_list = task_dls_dict.get('rm')
        full_dl_list = task_dls_dict.get(
            'full')  # these variables could be None
        LogInfo.end_track()

        if not args.test_only:  # won't perform training when just testing
            """ ==== Sub-task optimizing ==== """
            if epoch <= args.pre_train_steps or not full_back_prop:
                # pre-train stage, or always need train & update
                LogInfo.begin_track('Multi-task optimizing ... ')
                optm_schedule_list = []
                if 'el' in optm_tasks:
                    el_optimizer.reset_optm_info()
                    optm_schedule_list += [
                        ('el', x) for x in range(el_dl_list[0].n_batch)
                    ]
                    LogInfo.logs('[ el]: n_rows = %d, n_batch = %d.',
                                 len(el_dl_list[0]), el_dl_list[0].n_batch)
                if 'rm' in optm_tasks:
                    rm_optimizer.reset_optm_info()
                    optm_schedule_list += [
                        ('rm', x) for x in range(rm_dl_list[0].n_batch)
                    ]
                    LogInfo.logs('[ rm]: n_rows = %d, n_batch = %d.',
                                 len(rm_dl_list[0]), rm_dl_list[0].n_batch)
                np.random.shuffle(optm_schedule_list)
                LogInfo.logs('EL & RM task shuffled.')

                for task_name, batch_idx in optm_schedule_list:
                    if task_name == 'el':
                        el_optimizer.optimize(optm_dl=el_dl_list[0],
                                              batch_idx=batch_idx)
                    elif task_name == 'rm':
                        rm_optimizer.optimize(optm_dl=rm_dl_list[0],
                                              batch_idx=batch_idx)

                if 'el' in optm_tasks:
                    LogInfo.logs('[ el] loss = %.6f', el_optimizer.ret_loss)
                    disp_item_dict['el_loss'] = el_optimizer.ret_loss
                if 'rm' in optm_tasks:
                    LogInfo.logs('[ rm] loss = %.6f', rm_optimizer.ret_loss)
                    disp_item_dict['rm_loss'] = rm_optimizer.ret_loss
                LogInfo.end_track()  # End of optm.
        """ ==== Sub-task evluation, if possible ==== """
        if epoch <= args.pre_train_steps or not full_back_prop:
            for task, task_dl_list, evaluator in [
                ('el', el_dl_list, el_evaluator),
                ('rm', rm_dl_list, rm_evaluator)
            ]:
                if task not in eval_tasks:
                    continue
                LogInfo.begin_track('Evaluation for [%s]:', task)
                for mark, eval_dl in zip('Tvt', task_dl_list[1:]):
                    LogInfo.begin_track('Eval-%s ...', mark)
                    disp_key = '%s_%s_F1' % (task, mark)
                    detail_fp = '%s/detail/%s.%s.tmp' % (output_dir, task, mark
                                                         )  # detail/rm.T.tmp
                    disp_item_dict[disp_key] = evaluator.evaluate_all(
                        eval_dl=eval_dl, detail_fp=detail_fp)
                    LogInfo.end_track()
                LogInfo.end_track()
        """ ==== full optimization & evaluation, also prepare data for ltr ==== """
        if epoch > args.pre_train_steps or not full_back_prop:
            pyltr_data_list = []  # save T/v/t <q, [cand]> data for use by pyltr
            if 'full' in eval_tasks:
                LogInfo.begin_track('Full-task Optm & Eval:')
                if 'full' in optm_tasks:
                    LogInfo.begin_track('Optimizing ...')
                    LogInfo.logs('[full]: n_rows = %d, n_batch = %d.',
                                 len(full_dl_list[0]), full_dl_list[0].n_batch)
                    full_optimizer.optimize_all(
                        optm_dl=full_dl_list[0]
                    )  # quickly optimize the full model
                    LogInfo.logs('[full] loss = %.6f', full_optimizer.ret_loss)
                    disp_item_dict['full_loss'] = full_optimizer.ret_loss
                    LogInfo.end_track()
                for mark, eval_dl in zip('Tvt', full_dl_list[1:]):
                    LogInfo.begin_track('Eval-%s ...', mark)
                    disp_key = 'full_%s_F1' % mark
                    detail_fp = '%s/detail/full.%s.tmp' % (output_dir, mark)
                    disp_item_dict[disp_key] = full_evaluator.evaluate_all(
                        eval_dl=eval_dl, detail_fp=detail_fp)
                    pyltr_data_list.append(full_evaluator.ret_q_score_dict)
                    LogInfo.end_track()
                LogInfo.end_track()
        """ ==== LTR optimization & evaluation (non-TF code) ==== """
        if 'ltr' in eval_tasks:
            LogInfo.begin_track('LTR Optm & Eval:')
            assert len(pyltr_data_list) == 3
            LogInfo.logs('rich_feats_concat collected for all T/v/t schemas.')
            LogInfo.begin_track('Ready for ltr running ... ')
            ltr_metric_list = ltr_whole_process(
                pyltr_data_list=pyltr_data_list,
                eval_dl_list=full_dl_list[1:],
                output_dir=output_dir)
            LogInfo.end_track()
            for mark_idx, mark in enumerate(['T', 'v', 't']):
                key = 'ltr_%s_F1' % mark
                disp_item_dict[key] = ltr_metric_list[mark_idx]
            LogInfo.end_track()

        if not args.test_only:
            """ Display & save states (results, details, params) """
            validate_focus = '%s_v_F1' % full_optm_method
            if validate_focus in disp_item_dict:
                cur_valid_f1 = disp_item_dict[validate_focus]
                if cur_valid_f1 > best_valid_f1:
                    best_valid_f1 = cur_valid_f1
                    update_flag = True
                    patience = args.max_patience
                else:
                    patience -= 1
                LogInfo.logs('Model %s, best %s = %.6f [patience = %d]',
                             'updated' if update_flag else 'stayed',
                             validate_focus, cur_valid_f1, patience)
                disp_item_dict['Status'] = 'UPDATE' if update_flag else str(
                    patience)
            else:
                disp_item_dict['Status'] = '------'

            disp_item_dict['Time'] = datetime.now().strftime(
                '%Y-%m-%d_%H:%M:%S')
            with open(status_fp, 'a') as bw:
                write_str = ''
                for item_idx, header in enumerate(disp_header_list):
                    if header.endswith(' ') or header == '\t':  # just a split
                        write_str += header
                    else:
                        val = disp_item_dict.get(header, '--------')
                        if isinstance(val, float):
                            write_str += '%8.6f' % val
                        else:
                            write_str += str(val)
                bw.write(write_str + '\n')

            LogInfo.logs('Output concern parameters ... ')
            param_result_list = sess.run(
                focus_param_list
            )  # don't need any feeds, since we focus on parameters
            with open(output_dir + '/detail/param.%03d' % epoch, 'w') as bw:
                LogInfo.redirect(bw)
                np.set_printoptions(threshold=np.nan)
                for param_name, param_result in zip(focus_param_name_list,
                                                    param_result_list):
                    LogInfo.logs('%s: shape = %s ', param_name,
                                 param_result.shape)
                    LogInfo.logs(param_result)
                    LogInfo.logs('============================\n')
                np.set_printoptions()
                LogInfo.stop_redirect()

            if update_flag:  # save the latest details
                for mode in 'Tvt':
                    for task in ('rm', 'el', 'full', 'ltr'):
                        src = '%s/detail/%s.%s.tmp' % (output_dir, task, mode)
                        dest = '%s/detail/%s.%s.best' % (output_dir, task,
                                                         mode)
                        if os.path.isfile(src):
                            shutil.move(src, dest)
                if args.save_best:
                    save_best_dir = '%s/model_best' % output_dir
                    delete_dir(save_best_dir)
                    save_model(saver=saver,
                               sess=sess,
                               model_dir=save_best_dir,
                               epoch=epoch,
                               valid_metric=best_valid_f1)
        else:
            """ Output rank information for all schemas """
            LogInfo.begin_track(
                'Saving overall testing results between Q-words and main paths: '
            )
            mp_qw_dict = {}
            qw_mp_dict = {}
            qa_list = schema_dataset.qa_list
            count = 0
            for q_idx, sc_list in schema_dataset.smart_q_cand_dict.items():
                qa = qa_list[q_idx]
                raw_tok_list = [tok.token.lower() for tok in qa['tokens']]
                for sc in sc_list:
                    if sc.run_info is None or 'rm_score' not in sc.run_info:
                        continue
                    count += 1
                    score = sc.run_info['rm_score']
                    main_path = sc.main_pred_seq
                    main_path_str = '-->'.join(main_path)
                    main_path_words = sc.path_words_list[0]
                    main_path_words_str = ' | '.join(main_path_words)
                    mp = '[%s] [%s]' % (main_path_str, main_path_words_str)
                    rm_tok_list = RelationMatchingEvaluator.prepare_rm_tok_list(
                        sc=sc, raw_tok_list=raw_tok_list)
                    rm_tok_str = ' '.join(rm_tok_list).replace('<PAD>',
                                                               '').strip()
                    qw = 'Q-%04d [%s]' % (q_idx, rm_tok_str)
                    mp_qw_dict.setdefault(mp, []).append((qw, score))
                    qw_mp_dict.setdefault(qw, []).append((mp, score))
            LogInfo.logs('%d qw + %d path --> %d pairs ready to save.',
                         len(qw_mp_dict), len(mp_qw_dict), count)

            with codecs.open(output_dir + '/detail/mp_qw_results.txt', 'w',
                             'utf-8') as bw:
                mp_list = sorted(mp_qw_dict.keys())
                for mp in mp_list:
                    qw_score_tups = mp_qw_dict[mp]
                    bw.write(mp + '\n')
                    qw_score_tups.sort(key=lambda _tup: _tup[-1], reverse=True)
                    for rank_idx, (qw,
                                   score) in enumerate(qw_score_tups[:100]):
                        bw.write('  Rank=%04d    score=%9.6f    %s\n' %
                                 (rank_idx + 1, score, qw))
                    bw.write('\n===================================\n\n')
            LogInfo.logs('<path, qwords, score> saved.')

            with codecs.open(output_dir + '/detail/qw_mp_results.txt', 'w',
                             'utf-8') as bw:
                qw_list = sorted(qw_mp_dict.keys())
                for qw in qw_list:
                    mp_score_tups = qw_mp_dict[qw]
                    bw.write(qw + '\n')
                    mp_score_tups.sort(key=lambda _tup: _tup[-1], reverse=True)
                    for rank_idx, (mp,
                                   score) in enumerate(mp_score_tups[:100]):
                        bw.write('  Rank=%04d    score=%9.6f    %s\n' %
                                 (rank_idx + 1, score, mp))
                    bw.write('\n===================================\n\n')
            LogInfo.logs('<qwords, path, score> saved.')

            LogInfo.end_track()

            LogInfo.end_track()  # jump out of the epoch iteration
            break  # test-only mode, no need to perform more things.

        LogInfo.end_track()  # end of epoch
    LogInfo.end_track()  # end of learning
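
Buried inside the long loop of Example #4 is a plain validation-driven early-stopping scheme: whenever the chosen validation F1 improves, the best value is stored and patience is reset to max_patience, otherwise patience is decremented, and the next epoch breaks out of the loop once patience hits zero. A standalone toy of just that mechanism (the F1 series below is invented for illustration):

def run_epochs(valid_f1_per_epoch, max_patience=3):
    """Validation-driven early stopping, mirroring the patience handling in Example #4."""
    best_valid_f1, patience = 0., max_patience
    for epoch, cur_valid_f1 in enumerate(valid_f1_per_epoch, start=1):
        if patience == 0:                       # checked at the top, like the real loop
            print('Early stopping at epoch = %d.' % epoch)
            break
        update_flag = cur_valid_f1 > best_valid_f1
        if update_flag:
            best_valid_f1 = cur_valid_f1
            patience = max_patience             # improvement: reset patience
        else:
            patience -= 1                       # no improvement: burn one unit of patience
        print('Epoch %d: valid F1 = %.4f, model %s, patience = %d'
              % (epoch, cur_valid_f1, 'updated' if update_flag else 'stayed', patience))


run_epochs([0.40, 0.42, 0.41, 0.41, 0.41, 0.50])  # stops at epoch 6 without seeing the 0.50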
Example #5
def work(exp_dir, data_dir, best_epoch, qa_list, yih_ret_dict):
    log_fp = '%s/yih_compare_%03d.txt' % (exp_dir, best_epoch)

    pick_sc_dict = {q_idx: (-1, 0.) for q_idx in range(3778, 5810)}
    ret_fp = '%s/result/full.t.%03d' % (exp_dir, best_epoch)
    with open(ret_fp, 'r') as br:
        for line in br.readlines():
            spt = line.strip().split('\t')
            q_idx = int(spt[0])
            line_no = int(spt[1])
            ours_f1 = float(spt[2])
            pick_sc_dict[q_idx] = (line_no, ours_f1)

    disc = Discretizer([-0.99, -0.50, -0.25, -0.01, 0.01, 0.25, 0.50, 0.99])
    delta_tup_list = []
    avg_yih_f1 = 0.
    avg_ours_f1 = 0.
    for q_idx in range(3778, 5810):
        qa = qa_list[q_idx]
        q = qa['utterance']
        gold_answer_list = qa['targetValue']
        yih_answer_list = json.loads(yih_ret_dict[q])
        _, _, yih_f1 = compute_f1(goldList=gold_answer_list,
                                  predictedList=yih_answer_list)
        ours_f1 = pick_sc_dict[q_idx][1]
        avg_yih_f1 += yih_f1
        avg_ours_f1 += ours_f1
        delta = ours_f1 - yih_f1
        disc.convert(delta)
        delta_tup_list.append((q_idx, delta))
    avg_yih_f1 /= 2032      # 2032 test questions: q_idx in [3778, 5810)
    avg_ours_f1 /= 2032

    delta_tup_list.sort(key=lambda _tup: _tup[1])
    LogInfo.logs('%d questions delta sorted.', len(delta_tup_list))

    total_size = len(delta_tup_list)
    worse_size = sum(1 for _tup in delta_tup_list if _tup[1] < 0.)
    better_size = sum(1 for _tup in delta_tup_list if _tup[1] > 0.)
    equal_size = total_size - worse_size - better_size

    bw = codecs.open(log_fp, 'w', 'utf-8')
    LogInfo.redirect(bw)
    LogInfo.logs('Avg_Yih_F1 = %.6f, Avg_Ours_F1 = %.6f', avg_yih_f1,
                 avg_ours_f1)
    LogInfo.logs(' Worse cases = %d (%.2f%%)', worse_size,
                 100. * worse_size / total_size)
    LogInfo.logs(' Equal cases = %d (%.2f%%)', equal_size,
                 100. * equal_size / total_size)
    LogInfo.logs('Better cases = %d (%.2f%%)', better_size,
                 100. * better_size / total_size)
    disc.show_distribution()
    LogInfo.logs()
    for q_idx, _ in delta_tup_list:
        qa = qa_list[q_idx]
        line_no, ours_f1 = pick_sc_dict[q_idx]
        q = qa['utterance']
        yih_answer_list = json.loads(yih_ret_dict[q])
        if line_no == -1:
            continue
        single_question(q_idx=q_idx,
                        qa=qa,
                        data_dir=data_dir,
                        line_no=line_no,
                        yih_answer_list=yih_answer_list,
                        ours_f1=ours_f1)
    LogInfo.stop_redirect()
Example #6
def general_evaluate(eval_model,
                     data_loader,
                     epoch_idx,
                     ob_batch_num=20,
                     detail_fp=None,
                     result_fp=None,
                     summary_writer=None):
    """
    Evaluate all batches first, then compute the final score outside of TensorFlow.
    """
    assert isinstance(eval_model, BaseModel)
    if data_loader is None or len(data_loader) == 0:  # empty eval data
        return 0.
    """ Step 1: Run each batch, calculate case-specific results """

    eval_name_list = [tup[0] for tup in eval_model.eval_output_tf_tup_list]
    eval_tensor_list = [tup[1] for tup in eval_model.eval_output_tf_tup_list]
    assert isinstance(data_loader, DataLoader)
    eval_model.prepare_data(data_loader=data_loader)
    run_options = run_metadata = None
    if summary_writer is not None:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
    scan_size = 0
    ret_q_score_dict = {}  # <q, [(schema, score)]>
    for batch_idx in range(data_loader.n_batch):
        local_data_list, local_indices = data_loader.get_next_batch()
        local_size = len(
            local_data_list[0])  # the first dimension is always batch size
        fd = {
            input_tf: local_data
            for input_tf, local_data in zip(eval_model.eval_input_tf_list,
                                            local_data_list)
        }

        # Dynamically evaluate all the concerned tensors, no hard coding any more
        eval_result_list = eval_model.sess.run(eval_tensor_list,
                                               feed_dict=fd,
                                               options=run_options,
                                               run_metadata=run_metadata)
        scan_size += local_size
        if (batch_idx + 1) % ob_batch_num == 0:
            LogInfo.logs('[eval-%s-B%d/%d] scanned = %d/%d', data_loader.mode,
                         batch_idx + 1, data_loader.n_batch, scan_size,
                         len(data_loader))
        if summary_writer is not None and batch_idx == 0:
            summary_writer.add_run_metadata(run_metadata,
                                            'epoch-%d' % epoch_idx)

        for pos in range(
                len(local_indices)
        ):  # enumerate each data in the batch, and record corresponding scores
            local_idx = local_indices[pos]
            q_idx, cand = data_loader.cand_tup_list[local_idx]
            cand.run_info = {}
            for eval_name, eval_result in zip(eval_name_list,
                                              eval_result_list):
                cur_result = eval_result[pos]
                cand.run_info[eval_name] = cur_result
            ret_q_score_dict.setdefault(q_idx, []).append(cand)
    """ Step 2: After scanning all the batch, now count the final F1 result """

    f1_list = []
    for q_idx, score_list in ret_q_score_dict.items():
        score_list.sort(key=lambda x: x.run_info['score'],
                        reverse=True)  # sort by score DESC
        if len(score_list) == 0:
            f1_list.append(0.)
        else:
            f1_list.append(
                score_list[0].f1)  # pick the f1 of the highest scored schema
    LogInfo.logs('Predict %d out of %d questions.', len(f1_list),
                 data_loader.question_size)
    ret_metric = np.sum(f1_list) / data_loader.question_size
    """ Step 3: Got non-case-specific results, that are, parameters """

    param_name_list = [tup[0] for tup in eval_model.show_param_tf_tup_list]
    param_tensor_list = [tup[1] for tup in eval_model.show_param_tf_tup_list]
    param_result_list = eval_model.sess.run(
        param_tensor_list
    )  # don't need any feeds, since we focus on parameters
    """ Step 4: Save detail information: Schema Results & Parameters """

    if result_fp is not None:
        srt_q_idx_list = sorted(ret_q_score_dict.keys())
        with open(result_fp, 'w') as bw:  # write question --> selected schema
            for q_idx in srt_q_idx_list:
                srt_list = ret_q_score_dict[q_idx]
                ori_idx = -1
                f1 = 0.
                if len(srt_list) > 0:
                    best_sc = srt_list[0]
                    ori_idx = best_sc.ori_idx
                    f1 = best_sc.f1
                bw.write('%d\t%d\t%.6f\n' % (q_idx, ori_idx, f1))

    if detail_fp is not None:
        bw = open(detail_fp, 'w')
        LogInfo.redirect(bw)
        np.set_printoptions(threshold=np.nan)

        LogInfo.logs('Epoch-%d: avg_f1 = %.6f', epoch_idx, ret_metric)
        srt_q_idx_list = sorted(ret_q_score_dict.keys())
        for q_idx in srt_q_idx_list:
            q = data_loader.dataset.qa_list[q_idx]['utterance']
            LogInfo.begin_track('Q-%04d [%s]:', q_idx, q.encode('utf-8'))

            srt_list = ret_q_score_dict[q_idx]  # already sorted
            best_label_f1 = np.max([sc.f1 for sc in srt_list])
            best_label_f1 = max(best_label_f1, 0.000001)
            for rank, sc in enumerate(srt_list):
                if rank < 20 or sc.f1 == best_label_f1:
                    LogInfo.begin_track(
                        '#-%04d [F1 = %.6f] [row_in_file = %d]', rank + 1,
                        sc.f1, sc.ori_idx)
                    for eval_name in eval_name_list:
                        val = sc.run_info[eval_name]
                        if isinstance(val, float) or isinstance(
                                val,
                                np.float32):  # displaying a single float value
                            LogInfo.logs('%16s: %9.6f', eval_name,
                                         sc.run_info[eval_name])
                        elif isinstance(val, np.ndarray):
                            if 'att_mat' in eval_name:  # displaying several attention matrices
                                disp_mat_list = att_matrix_auto_crop(
                                    att_mat=val)
                                for idx, crop_mat in enumerate(disp_mat_list):
                                    if np.prod(crop_mat.shape) == 0:
                                        continue  # skip attention matrix of padding paths
                                    LogInfo.begin_track(
                                        '%16s: shape = %s',
                                        '%s-%d' % (eval_name, idx),
                                        crop_mat.shape)
                                    LogInfo.logs(
                                        crop_mat
                                    )  # output attention matrix one-by-one
                                    LogInfo.end_track()
                            else:  # displaying a single ndarray
                                LogInfo.begin_track('%16s: shape = %s',
                                                    eval_name, val.shape)
                                LogInfo.logs(val)
                                LogInfo.end_track()
                        else:
                            LogInfo.logs('%16s: illegal value %s', eval_name,
                                         type(val))
                    for path_idx, path in enumerate(sc.path_list):
                        LogInfo.logs('Path-%d: [%s]', path_idx,
                                     '-->'.join(path))
                    for path_idx, words in enumerate(sc.path_words_list):
                        LogInfo.logs('Path-Word-%d: [%s]', path_idx,
                                     ' | '.join(words).encode('utf-8'))
                    LogInfo.end_track()
            LogInfo.end_track()
        LogInfo.logs('Epoch-%d: avg_f1 = %.6f', epoch_idx, ret_metric)

        LogInfo.logs('=================== Parameters ===================')
        for param_name, param_result in zip(param_name_list,
                                            param_result_list):
            LogInfo.begin_track('%s: shape = %s ', param_name,
                                param_result.shape)
            LogInfo.logs(param_result)
            LogInfo.end_track()

        np.set_printoptions()  # reset output format
        LogInfo.stop_redirect()
        bw.close()
    return ret_metric
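
The result_fp written in Step 4 above is a plain TSV with one line per question: the question index, the row index (ori_idx) of the selected schema in its candidate file, and the F1 of that schema, with ori_idx = -1 when no candidate was ranked. Example #5 earlier parses exactly this layout back into pick_sc_dict. A small round-trip sketch (the temporary path and rows are made up):

import os
import tempfile

rows = [(3778, 15, 0.5), (3779, -1, 0.0)]        # (q_idx, ori_idx, f1)
result_fp = os.path.join(tempfile.gettempdir(), 'full.t.demo')

with open(result_fp, 'w') as bw:                 # writer side, as in general_evaluate()
    for q_idx, ori_idx, f1 in rows:
        bw.write('%d\t%d\t%.6f\n' % (q_idx, ori_idx, f1))

pick_sc_dict = {}                                # reader side, as in Example #5
with open(result_fp, 'r') as br:
    for line in br:
        spt = line.strip().split('\t')
        pick_sc_dict[int(spt[0])] = (int(spt[1]), float(spt[2]))
print(pick_sc_dict)                              # {3778: (15, 0.5), 3779: (-1, 0.0)}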
Example #7
    def evaluate(self,
                 data_loader,
                 epoch_idx,
                 ob_batch_num=20,
                 detail_fp=None,
                 result_fp=None,
                 summary_writer=None):
        if data_loader is None or len(data_loader) == 0:  # empty eval data
            return 0.

        assert isinstance(data_loader, CompqSingleDataLoader)
        self.prepare_data(data_loader=data_loader)
        run_options = run_metadata = None
        if summary_writer is not None:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()

        scan_size = 0
        ret_q_score_dict = {}  # <q, [(schema, score)]>
        for batch_idx in range(data_loader.n_batch):
            local_data_list, local_indices = data_loader.get_next_batch()
            local_size = len(
                local_data_list[0])  # the first dimension is always batch size
            fd = {
                input_tf: local_data
                for input_tf, local_data in zip(self.eval_input_tf_list,
                                                local_data_list)
            }

            # score_mat, sc_final_rep, att_tensor, q_weight, sc_weight, summary = \
            # local_score_list, summary = self.sess.run(
            local_score_list = self.sess.run(
                # [self.score_mat, self.sc_final_rep, self.att_tensor,
                #  self.q_weight, self.sc_weight, self.eval_summary],
                # [self.score, self.eval_summary],
                self.score,
                feed_dict=fd,
                options=run_options,
                run_metadata=run_metadata)
            scan_size += local_size
            if (batch_idx + 1) % ob_batch_num == 0:
                LogInfo.logs('[eval-%s-B%d/%d] scanned = %d/%d',
                             data_loader.mode, batch_idx + 1,
                             data_loader.n_batch, scan_size, len(data_loader))

            if summary_writer is not None:
                if batch_idx == 0:
                    summary_writer.add_run_metadata(run_metadata,
                                                    'epoch-%d' % epoch_idx)

            # for local_idx, score_vec, sc_mat, local_att_tensor, local_q_weight, local_sc_weight in zip(
            #         local_indices, score_mat, sc_final_rep, att_tensor, q_weight, sc_weight):  # enumerate each row
            for local_idx, score in zip(local_indices, local_score_list):
                q_idx, cand = data_loader.cand_tup_list[local_idx]
                assert isinstance(cand, CompqSchema)
                score_list = ret_q_score_dict.setdefault(
                    q_idx, [])  # save all candidates with their scores
                cand.run_info = {
                    'score': score,
                    # 'sc_vec': sc_vec,
                    # 'att_mat': att_mat,
                    # 'q_weight_vec': q_weight_vec,
                    # 'sc_weight_vec': sc_weight_vec
                }  # save detail information within the cand.
                score_list.append(cand)
                assert isinstance(cand, CompqSchema)

        # After scanning all the batch, now count the final F1 result
        f1_list = []
        for q_idx, score_list in ret_q_score_dict.items():
            score_list.sort(key=lambda x: x.run_info['score'],
                            reverse=True)  # sort by score DESC
            if len(score_list) == 0:
                f1_list.append(0.)
            else:
                f1_list.append(score_list[0].f1
                               )  # pick the f1 of the highest scored schema
        LogInfo.logs('Predict %d out of %d questions.', len(f1_list),
                     data_loader.question_size)
        ret_metric = np.sum(f1_list) / data_loader.question_size

        if result_fp is not None:
            srt_q_idx_list = sorted(ret_q_score_dict.keys())
            with open(result_fp,
                      'w') as bw:  # write question --> selected schema
                for q_idx in srt_q_idx_list:
                    srt_list = ret_q_score_dict[q_idx]
                    ori_idx = -1
                    f1 = 0.
                    if len(srt_list) > 0:
                        best_sc = srt_list[0]
                        ori_idx = best_sc.ori_idx
                        f1 = best_sc.f1
                    bw.write('%d\t%d\t%.6f\n' % (q_idx, ori_idx, f1))

        if detail_fp is not None:
            bw = open(detail_fp, 'w')
            LogInfo.redirect(bw)
            LogInfo.logs('Epoch-%d: avg_f1 = %.6f', epoch_idx, ret_metric)
            srt_q_idx_list = sorted(ret_q_score_dict.keys())
            for q_idx in srt_q_idx_list:
                LogInfo.begin_track(
                    'Q-%04d [%s]:', q_idx,
                    data_loader.dataset.q_list[q_idx].encode('utf-8'))
                # q_words = data_loader.dataset.q_words_dict[q_idx]
                # word_surf_list = map(lambda x: w_uhash[x].encode('utf-8'), q_words)

                srt_list = ret_q_score_dict[q_idx]  # already sorted
                best_label_f1 = np.max([sc.f1 for sc in srt_list])
                best_label_f1 = max(best_label_f1, 0.000001)
                for rank, sc in enumerate(srt_list):
                    # path_sz = len(sc.path_list)
                    # path_surf_list = map(lambda x: 'Path-%d' % x, range(path_sz))
                    if rank < 5 or sc.f1 == best_label_f1:
                        LogInfo.begin_track('#-%04d: F1=%.6f, score=%9.6f ',
                                            rank + 1, sc.f1,
                                            sc.run_info['score'])
                        for path_idx, path in enumerate(sc.path_list):
                            LogInfo.logs('Path-%d: [%s]', path_idx,
                                         '-->'.join(path))
                        for path_idx, words in enumerate(sc.path_words_list):
                            LogInfo.logs('Path-Word-%d: [%s]', path_idx,
                                         ' | '.join(words).encode('utf-8'))
                        # LogInfo.logs('Raw Attention Score:')
                        # self.print_att_matrix(word_surf_list=word_surf_list,
                        #                       path_surf_list=path_surf_list,
                        #                       att_mat=sc.run_info['att_mat'])
                        # LogInfo.logs('Q-side Attention Weight:')
                        # self.print_weight_vec(item_surf_list=word_surf_list,
                        #                       weight_vec=sc.run_info['q_weight_vec'])
                        # LogInfo.logs('SC-side Attention Weight:')
                        # self.print_weight_vec(item_surf_list=path_surf_list,
                        #                       weight_vec=sc.run_info['sc_weight_vec'])
                        LogInfo.end_track()
                LogInfo.end_track()
            LogInfo.logs('Epoch-%d: avg_f1 = %.6f', epoch_idx, ret_metric)
            LogInfo.stop_redirect()
            bw.close()
        return ret_metric
Example #8
    def post_process(self, eval_dl, detail_fp, result_fp):
        if len(self.eval_detail_dict) == 0:
            self.eval_detail_dict = {k: [] for k in self.concern_name_list}

        ret_q_score_dict = {}
        for scan_idx, (q_idx, cand) in enumerate(eval_dl.eval_sc_tup_list):
            cand.run_info = {
                k: self.eval_detail_dict[k][scan_idx]
                for k in self.concern_name_list
            }
            ret_q_score_dict.setdefault(q_idx, []).append(cand)
            # put all output results into sc.run_info

        rm_f1_list = []
        for q_idx, score_list in ret_q_score_dict.items():
            score_list.sort(key=lambda x: x.run_info['rm_score'],
                            reverse=True)  # sort by score DESC
            if len(score_list) == 0:
                rm_f1_list.append(0.)
            else:
                rm_f1_list.append(score_list[0].rm_f1)
        LogInfo.logs('[%3s] Predict %d out of %d questions.', self.name,
                     len(rm_f1_list), eval_dl.total_questions)
        ret_metric = np.sum(rm_f1_list).astype(
            'float32') / eval_dl.total_questions
        """ Save detail information """
        if result_fp is not None:
            srt_q_idx_list = sorted(ret_q_score_dict.keys())
            with open(result_fp,
                      'w') as bw:  # write question --> selected schema
                for q_idx in srt_q_idx_list:
                    srt_list = ret_q_score_dict[q_idx]
                    ori_idx = -1
                    rm_f1 = 0.
                    if len(srt_list) > 0:
                        best_sc = srt_list[0]
                        ori_idx = best_sc.ori_idx
                        rm_f1 = best_sc.rm_f1
                    bw.write('%d\t%d\t%.6f\n' % (q_idx, ori_idx, rm_f1))

        if detail_fp is not None:
            # use_p = self.compq_mt_model.rm_kernel.use_p
            # use_pw = self.compq_mt_model.rm_kernel.use_pw
            # path_usage = self.compq_mt_model.rm_kernel.path_usage
            # pw_max_len = self.compq_mt_model.pword_max_len
            # p_max_len = self.compq_mt_model.path_max_len
            schema_dataset = eval_dl.schema_dataset
            bw = open(detail_fp, 'w')
            LogInfo.redirect(bw)
            np.set_printoptions(threshold=np.nan)
            LogInfo.logs('Avg_rm_f1 = %.6f', ret_metric)
            srt_q_idx_list = sorted(ret_q_score_dict.keys())
            for q_idx in srt_q_idx_list:
                qa = schema_dataset.qa_list[q_idx]
                q = qa['utterance']
                LogInfo.begin_track('Q-%04d [%s]:', q_idx, q.encode('utf-8'))

                srt_list = ret_q_score_dict[q_idx]  # already sorted
                best_label_f1 = np.max([sc.rm_f1 for sc in srt_list])
                best_label_f1 = max(best_label_f1, 0.000001)
                for rank, sc in enumerate(srt_list):
                    if rank < 20 or sc.rm_f1 == best_label_f1:
                        LogInfo.begin_track(
                            '#-%04d [rm_F1 = %.6f] [row_in_file = %d]',
                            rank + 1, sc.rm_f1, sc.ori_idx)
                        LogInfo.logs('rm_score: %.6f', sc.run_info['rm_score'])
                        # self.show_att_mat(sc=sc, qa=qa)
                        LogInfo.logs('Current: not output detail.')
                        LogInfo.end_track()
                LogInfo.end_track()
            LogInfo.logs('Avg_rm_f1 = %.6f', ret_metric)

            # LogInfo.logs('=================== Parameters ===================')
            # for param_name, param_result in zip(param_name_list, param_result_list):
            #     LogInfo.begin_track('%s: shape = %s ', param_name, param_result.shape)
            #     LogInfo.logs(param_result)
            #     LogInfo.end_track()

            np.set_printoptions()  # reset output format
            LogInfo.stop_redirect()
            bw.close()

        return [ret_metric]
Example #9
def working_in_data(data_dir, file_list, qa_list, sc_max_len=3):

    save_fp = data_dir + '/log.schema_check'
    bw = open(save_fp, 'w')
    LogInfo.redirect(bw)

    LogInfo.begin_track('Working in %s, with schemas in %s:', data_dir,
                        file_list)
    with open(data_dir + '/' + file_list, 'r') as br:
        lines = br.readlines()

    used_pred_tup_list = []
    stat_tup_list = []
    for line_idx, line in enumerate(lines):
        if line_idx % 100 == 0:
            LogInfo.logs('Scanning %d / %d ...', line_idx, len(lines))
        line = line.strip()
        q_idx = int(line.split('/')[2].split('_')[0])
        # if q_idx != 1353:
        #     continue
        schema_fp = data_dir + '/' + line
        link_fp = schema_fp.replace('schema', 'links')
        Tt.start('read_linkings')
        gather_linkings = []
        with codecs.open(link_fp, 'r', 'utf-8') as br:
            for gl_line in br.readlines():
                tup_list = json.loads(gl_line.strip())
                ld_dict = {k: v for k, v in tup_list}
                gather_linkings.append(LinkData(**ld_dict))
        Tt.record('read_linkings')
        Tt.start('read_schema')
        global_lists, used_pred_name_dict = \
            read_schemas_from_single_file(q_idx, schema_fp, gather_linkings, sc_max_len)
        Tt.record('read_schema')
        stat_tup_list.append((q_idx, global_lists))
        used_pred_tup_list.append((q_idx, used_pred_name_dict))

    stat_tup_list.sort(
        key=lambda _tup: max_f1(_tup[1][ELEGANT]) - max_f1(_tup[1][STRICT]),
        reverse=True)

    LogInfo.logs('sc_len distribution: %s', sc_len_list)  # sc_len_list is populated elsewhere (not shown in this snippet)

    LogInfo.logs('Rank\tQ_idx\tstri.\teleg.\tcohe.\tgene.'
                 '\tstri._F1\teleg._F1\tcohe._F1\tgene._F1\tUtterance')
    for rank, stat_tup in enumerate(stat_tup_list):
        q_idx, global_lists = stat_tup
        size_list = ['%4d' % len(cand_list) for cand_list in global_lists]
        max_f1_list = [
            '%8.6f' % max_f1(cand_list) for cand_list in global_lists
        ]
        size_str = '\t'.join(size_list)
        max_f1_str = '\t'.join(max_f1_list)
        LogInfo.logs('%4d\t%4d\t%s\t%s\t%s', rank + 1, q_idx, size_str,
                     max_f1_str, qa_list[q_idx]['utterance'].encode('utf-8'))

    q_size = len(stat_tup_list)
    f1_upperbound_list = []
    avg_cand_size_list = []
    found_entity_list = []
    for index in range(4):
        local_max_f1_list = []
        local_cand_size_list = []
        for _, global_lists in stat_tup_list:
            sc_list = global_lists[index]
            local_max_f1_list.append(max_f1(sc_list))
            local_cand_size_list.append(len(sc_list))
        local_found_entity = sum(1 for y in local_max_f1_list if y > 0.)
        f1_upperbound_list.append(1.0 * sum(local_max_f1_list) / q_size)
        avg_cand_size_list.append(1.0 * sum(local_cand_size_list) / q_size)
        found_entity_list.append(local_found_entity)

    LogInfo.logs('  strict: avg size = %.6f, upp F1 = %.6f',
                 avg_cand_size_list[STRICT], f1_upperbound_list[STRICT])
    for name, index in zip(('strict', 'elegant', 'coherent', 'general'),
                           (STRICT, ELEGANT, COHERENT, GENERAL)):
        avg_size = avg_cand_size_list[index]
        ratio = 1. * avg_size / avg_cand_size_list[
            STRICT] if avg_cand_size_list[STRICT] > 0 else 0.
        found_entity = found_entity_list[index]
        upp_f1 = f1_upperbound_list[index]
        gain = upp_f1 - f1_upperbound_list[STRICT]
        LogInfo.logs(
            '%8s: avg size = %.6f (%.2fx), found entity = %d, '
            'upp F1 = %.6f, gain = %.6f', name, avg_size, ratio, found_entity,
            upp_f1, gain)
    LogInfo.end_track()

    bw.close()
    LogInfo.stop_redirect()

    used_pred_fp = data_dir + '/log.used_pred_stat'
    bw = codecs.open(used_pred_fp, 'w', 'utf-8')
    LogInfo.redirect(bw)
    ratio_list = []
    for q_idx, used_pred_name_dict in used_pred_tup_list:
        unique_id_size = len(set(used_pred_name_dict.keys()))
        unique_name_size = len(set(used_pred_name_dict.values()))
        ratio = 1. * unique_name_size / unique_id_size if unique_id_size > 0 else 1.
        ratio_list.append(ratio)
        LogInfo.logs('Q = %4d, unique id = %d, unique name = %d, ratio = %.4f',
                     q_idx, unique_id_size, unique_name_size, ratio)
    avg_ratio = sum(ratio_list) / len(ratio_list)
    LogInfo.logs('avg_ratio = %.4f', avg_ratio)
    LogInfo.stop_redirect()
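
The pool statistics at the end of Example #9 reduce to two numbers per candidate-generation strategy: the oracle upper bound (the average, over questions, of the best F1 reachable among that question's candidates) and how many questions have at least one candidate with F1 > 0. A toy version, with candidates reduced to bare F1 floats instead of schema objects and pool contents invented for illustration:

def max_f1(cand_f1_list):
    # toy max_f1: candidates here are plain F1 floats, not schema objects
    return max(cand_f1_list) if cand_f1_list else 0.


# per-question candidate pools for two hypothetical generation strategies
pools = {
    'strict':  [[0.0, 0.5], [0.0], []],
    'elegant': [[0.0, 0.5, 0.9], [0.3], [0.0]],
}
q_size = 3
for name in ('strict', 'elegant'):
    per_question = pools[name]
    avg_size = 1.0 * sum(len(cands) for cands in per_question) / q_size
    found_entity = sum(1 for cands in per_question if max_f1(cands) > 0.)
    upp_f1 = 1.0 * sum(max_f1(cands) for cands in per_question) / q_size
    print('%8s: avg size = %.2f, found entity = %d, upp F1 = %.6f'
          % (name, avg_size, found_entity, upp_f1))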
Example #10
    def evaluate(self, data_loader, epoch_idx, detail_fp=None, summary_writer=None):
        if data_loader is None:
            return 0.

        assert isinstance(data_loader, QScEvalDataLoader)
        scan_size = 0
        ret_metric = 0.
        if data_loader.dynamic or data_loader.np_data_list is None:
            data_loader.renew_data_list()
        if self.verbose > 0:
            LogInfo.logs('num of batch = %d.', data_loader.n_batch)

        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        ret_q_score_dict = {}  # <q, [(schema, score)]>
        for batch_idx in range(data_loader.n_batch):
            point = (epoch_idx - 1) * data_loader.n_batch + batch_idx
            local_data_list, local_indices = data_loader.get_next_batch()
            local_size = len(local_data_list[0])    # the first dimension is always batch size
            fd = {input_tf: local_data for input_tf, local_data in zip(self.eval_input_tf_list, local_data_list)}

            if summary_writer is None:
                score_mat, local_metric = self.sess.run([self.score_mat, self.eval_metric_val], feed_dict=fd)
                summary = None
            else:
                score_mat, local_metric, summary = self.sess.run(
                    [self.score_mat, self.eval_metric_val, self.eval_summary],
                    feed_dict=fd, options=run_options, run_metadata=run_metadata)
            ret_metric = (ret_metric * scan_size + local_metric * local_size) / (scan_size + local_size)
            scan_size += local_size
            if (batch_idx+1) % self.ob_batch_num == 0:
                LogInfo.logs('[eval-%s-B%d/%d] metric = %.6f, scanned = %d/%d',
                             data_loader.mode,
                             batch_idx+1,
                             data_loader.n_batch,
                             ret_metric,
                             scan_size,
                             len(data_loader))
            if summary_writer is not None:
                summary_writer.add_summary(summary, point)
                if batch_idx == 0:
                    summary_writer.add_run_metadata(run_metadata, 'epoch-%d' % epoch_idx)

            for local_idx, score_vec in zip(local_indices, score_mat):      # enumerate each row
                q_idx, cands = data_loader.q_cands_tup_list[local_idx]
                score_tup_list = ret_q_score_dict.setdefault(q_idx, [])
                for cand_idx, cand in enumerate(cands):  # enumerate each candidate
                    score = score_vec[cand_idx]
                    score_tup_list.append((cand, score))

        if detail_fp is not None:
            bw = open(detail_fp, 'w')
            LogInfo.redirect(bw)
            LogInfo.logs('Epoch-%d: avg_f1 = %.6f', epoch_idx, ret_metric)
            srt_q_idx_list = sorted(ret_q_score_dict.keys())
            for q_idx in srt_q_idx_list:
                LogInfo.begin_track('Q-%04d [%s]:', q_idx,
                                    data_loader.dataset.webq_list[q_idx].encode('utf-8'))
                score_tup_list = ret_q_score_dict[q_idx]
                score_tup_list.sort(key=lambda x: x[1], reverse=True)
                best_label_f1 = np.max([tup[0].f1 for tup in score_tup_list])
                best_label_f1 = max(best_label_f1, 0.000001)
                for rank, tup in enumerate(score_tup_list):
                    schema, score = tup
                    if rank < 5 or schema.f1 == best_label_f1:
                        LogInfo.logs('%4d: F1=%.6f, score=%9.6f, schema=[%s]',
                                     rank+1, schema.f1, score, schema.path_list_str)
                        schema.display_embedding_info()
                LogInfo.end_track()
            LogInfo.logs('Epoch-%d: avg_f1 = %.6f', epoch_idx, ret_metric)
            LogInfo.stop_redirect()
            bw.close()

        return ret_metric
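The ret_metric update inside the batch loop of this evaluate() is a size-weighted running mean, so the final value equals the overall weighted average even when the last batch is smaller. A small self-contained sketch with made-up batch metrics and sizes:

# Self-contained sketch of the size-weighted running mean; the inputs are hypothetical.
def streaming_mean(batch_metrics, batch_sizes):
    scan_size = 0
    ret_metric = 0.
    for local_metric, local_size in zip(batch_metrics, batch_sizes):
        ret_metric = (ret_metric * scan_size + local_metric * local_size) / (scan_size + local_size)
        scan_size += local_size
    return ret_metric

# Matches the plain weighted average over all scanned rows.
assert abs(streaming_mean([0.5, 0.7], [32, 8]) - (0.5 * 32 + 0.7 * 8) / 40.) < 1e-9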
Example #11
    def post_process(self, eval_dl, detail_fp, result_fp):
        assert len(self.eval_detail_dict) > 0

        ret_q_score_dict = {}
        for scan_idx, (q_idx, cand) in enumerate(eval_dl.eval_sc_tup_list):
            cand.run_info = {k: data_values[scan_idx] for k, data_values in self.eval_detail_dict.items()}
            ret_q_score_dict.setdefault(q_idx, []).append(cand)
            # put all output results into sc.run_info

        score_key = '%s_score' % self.task_name
        f1_key = 'f1' if self.task_name == 'full' else '%s_f1' % self.task_name
        f1_list = []
        for q_idx, score_list in ret_q_score_dict.items():
            score_list.sort(key=lambda x: x.run_info[score_key], reverse=True)  # sort by score DESC
            if len(score_list) == 0:
                f1_list.append(0.)
            else:
                f1_list.append(getattr(score_list[0], f1_key))
        LogInfo.logs('[%3s] Predict %d out of %d questions.', self.task_name, len(f1_list), eval_dl.total_questions)
        ret_metric = np.sum(f1_list).astype('float32') / eval_dl.total_questions

        if detail_fp is not None:
            schema_dataset = eval_dl.schema_dataset
            bw = open(detail_fp, 'w')
            LogInfo.redirect(bw)
            np.set_printoptions(threshold=np.nan)
            LogInfo.logs('Avg_%s_f1 = %.6f', self.task_name, ret_metric)
            srt_q_idx_list = sorted(ret_q_score_dict.keys())
            for q_idx in srt_q_idx_list:
                qa = schema_dataset.qa_list[q_idx]
                q = qa['utterance']
                LogInfo.begin_track('Q-%04d [%s]:', q_idx, q.encode('utf-8'))
                srt_list = ret_q_score_dict[q_idx]  # already sorted
                best_label_f1 = np.max([getattr(sc, f1_key) for sc in srt_list])
                best_label_f1 = max(best_label_f1, 0.000001)
                for rank, sc in enumerate(srt_list):
                    cur_f1 = getattr(sc, f1_key)
                    if rank < 20 or cur_f1 == best_label_f1:
                        LogInfo.begin_track('#-%04d [%s_F1 = %.6f] [row_in_file = %d]',
                                            rank+1, self.task_name, cur_f1, sc.ori_idx)
                        LogInfo.logs('%s: %.6f', score_key, sc.run_info[score_key])
                        if self.detail_disp_func is not None:
                            self.detail_disp_func(sc=sc, qa=qa, schema_dataset=schema_dataset)
                        else:
                            LogInfo.logs('Current: not output detail.')
                        LogInfo.end_track()
                LogInfo.end_track()
            LogInfo.logs('Avg_%s_f1 = %.6f', self.task_name, ret_metric)

            np.set_printoptions()  # reset output format
            LogInfo.stop_redirect()
            bw.close()

        """ Save detail information """
        if result_fp is not None:
            srt_q_idx_list = sorted(ret_q_score_dict.keys())
            with open(result_fp, 'w') as bw:  # write question --> selected schema
                for q_idx in srt_q_idx_list:
                    srt_list = ret_q_score_dict[q_idx]
                    ori_idx = -1
                    task_f1 = 0.
                    if len(srt_list) > 0:
                        best_sc = srt_list[0]
                        ori_idx = best_sc.ori_idx
                        task_f1 = getattr(best_sc, f1_key)
                    bw.write('%d\t%d\t%.6f\n' % (q_idx, ori_idx, task_f1))

        return ret_metric
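The metric produced by post_process() is a macro average of the top-ranked candidate's F1 per question, with unanswered questions counted as zero because the division uses eval_dl.total_questions rather than the number of predicted questions. A hedged sketch using plain dicts in place of the candidate objects:

# Hedged sketch: candidates are plain dicts here instead of the schema objects used above.
def macro_top1_f1(ret_q_score_dict, total_questions, score_key='full_score', f1_key='f1'):
    f1_sum = 0.
    for score_list in ret_q_score_dict.values():
        score_list.sort(key=lambda c: c[score_key], reverse=True)  # sort by score DESC
        if score_list:
            f1_sum += score_list[0][f1_key]   # F1 of the top-ranked candidate
    return f1_sum / total_questions           # unanswered questions count as 0

cands = {0: [{'full_score': 0.3, 'f1': 1.0}, {'full_score': 1.2, 'f1': 0.4}]}
# One question answered (top-1 F1 = 0.4), one question without candidates at all.
assert abs(macro_top1_f1(cands, total_questions=2) - 0.2) < 1e-9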
Example #12
    def evaluate(self,
                 data_loader,
                 epoch_idx,
                 ob_batch_num=10,
                 detail_fp=None,
                 summary_writer=None):
        if data_loader is None or len(data_loader) == 0:  # empty eval data
            return 0.

        assert isinstance(data_loader, QScEvalDynamicDataLoader)
        self.prepare_data(data_loader=data_loader)
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        scan_size = 0
        ret_q_score_dict = {}  # <q, [(schema, score)]>
        for batch_idx in range(data_loader.n_batch):
            point = (epoch_idx - 1) * data_loader.n_batch + batch_idx
            local_data_list, local_indices = data_loader.get_next_batch()
            local_size = len(
                local_data_list[0])  # the first dimension is always batch size
            fd = {
                input_tf: local_data
                for input_tf, local_data in zip(self.eval_input_tf_list,
                                                local_data_list)
            }

            score_mat, sc_final_rep, att_tensor, q_weight, sc_weight, summary = \
                self.sess.run([self.score_mat, self.sc_final_rep, self.att_tensor,
                               self.q_weight, self.sc_weight, self.eval_summary],
                              feed_dict=fd,
                              options=run_options,
                              run_metadata=run_metadata)
            scan_size += local_size
            if (batch_idx + 1) % ob_batch_num == 0:
                LogInfo.logs('[eval-%s-B%d/%d] scanned = %d/%d',
                             data_loader.mode, batch_idx + 1,
                             data_loader.n_batch, scan_size, len(data_loader))

            if summary_writer is not None:
                summary_writer.add_summary(summary, point)
                if batch_idx == 0:
                    summary_writer.add_run_metadata(run_metadata,
                                                    'epoch-%d' % epoch_idx)

            for local_idx, score_vec, sc_mat, local_att_tensor, local_q_weight, local_sc_weight in zip(
                    local_indices, score_mat, sc_final_rep, att_tensor,
                    q_weight, sc_weight):  # enumerate each row
                q_idx, cands = data_loader.q_cands_tup_list[local_idx]
                score_list = ret_q_score_dict.setdefault(
                    q_idx, [])  # save all candidates with their scores
                for cand_idx, cand in enumerate(
                        cands):  # enumerate each candidate
                    score = score_vec[cand_idx]
                    sc_vec = sc_mat[cand_idx]
                    att_mat = local_att_tensor[cand_idx]
                    q_weight_vec = local_q_weight[cand_idx]
                    sc_weight_vec = local_sc_weight[cand_idx]
                    cand.run_info = {
                        'score': score,
                        'sc_vec': sc_vec,
                        'att_mat': att_mat,
                        'q_weight_vec': q_weight_vec,
                        'sc_weight_vec': sc_weight_vec
                    }  # save detail information within the cand.
                    score_list.append(cand)

        # After scanning all the batch, now count the final F1 result
        f1_list = []
        for q_idx, score_list in ret_q_score_dict.items():
            score_list.sort(key=lambda x: x.run_info['score'],
                            reverse=True)  # sort by score DESC
            if len(score_list) == 0:
                f1_list.append(0.)
            else:
                f1_list.append(score_list[0].f1)  # pick the F1 of the highest-scored schema
        ret_metric = np.mean(f1_list)

        if detail_fp is not None:
            data_loader.dataset.load_dicts()
            w_dict = data_loader.dataset.w_dict
            w_uhash = {v: k for k, v in w_dict.items()}  # word_idx --> word

            bw = open(detail_fp, 'w')
            LogInfo.redirect(bw)
            LogInfo.logs('Epoch-%d: avg_f1 = %.6f', epoch_idx, ret_metric)
            srt_q_idx_list = sorted(ret_q_score_dict.keys())
            for q_idx in srt_q_idx_list:
                LogInfo.begin_track(
                    'Q-%04d [%s]:', q_idx,
                    data_loader.dataset.webq_list[q_idx].encode('utf-8'))
                q_words = data_loader.dataset.q_words_dict[q_idx]
                word_surf_list = map(lambda x: w_uhash[x].encode('utf-8'),
                                     q_words)

                srt_list = ret_q_score_dict[q_idx]  # already sorted
                best_label_f1 = np.max([sc.f1 for sc in srt_list])
                best_label_f1 = max(best_label_f1, 0.000001)
                for rank, sc in enumerate(srt_list):
                    path_sz = len(sc.path_list)
                    path_surf_list = map(lambda x: 'Path-%d' % x,
                                         range(path_sz))
                    if rank < 5 or sc.f1 == best_label_f1:
                        LogInfo.begin_track('#-%04d: F1=%.6f, score=%9.6f ',
                                            rank + 1, sc.f1,
                                            sc.run_info['score'])
                        for path_idx, path in enumerate(sc.path_list):
                            LogInfo.logs('Path-%d: [%s]', path_idx,
                                         '-->'.join(path))
                        for path_idx, words in enumerate(sc.path_words_list):
                            LogInfo.logs('Path-Word-%d: [%s]', path_idx,
                                         ' | '.join(words).encode('utf-8'))
                        LogInfo.logs('Raw Attention Score:')
                        self.print_att_matrix(word_surf_list=word_surf_list,
                                              path_surf_list=path_surf_list,
                                              att_mat=sc.run_info['att_mat'])
                        LogInfo.logs('Q-side Attention Weight:')
                        self.print_weight_vec(
                            item_surf_list=word_surf_list,
                            weight_vec=sc.run_info['q_weight_vec'])
                        LogInfo.logs('SC-side Attention Weight:')
                        self.print_weight_vec(
                            item_surf_list=path_surf_list,
                            weight_vec=sc.run_info['sc_weight_vec'])
                LogInfo.end_track()
            LogInfo.logs('Epoch-%d: avg_f1 = %.6f', epoch_idx, ret_metric)
            LogInfo.stop_redirect()
            bw.close()
        return ret_metric
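The helpers print_att_matrix and print_weight_vec are not part of this snippet; the sketch below is only a guess at a comparable plain-text rendering of an attention matrix, with question words as rows and schema paths as columns:

import numpy as np

# Guess at a comparable text rendering; not the actual print_att_matrix implementation.
def render_att_matrix(word_surf_list, path_surf_list, att_mat):
    lines = ['%12s' % '' + ''.join('%12s' % p for p in path_surf_list)]
    for word, row in zip(word_surf_list, att_mat):
        lines.append('%12s' % word + ''.join('%12.4f' % v for v in row))
    return '\n'.join(lines)

print(render_att_matrix(['who', 'founded', 'apple'],
                        ['Path-0', 'Path-1'],
                        np.array([[0.10, 0.05], [0.60, 0.20], [0.30, 0.75]])))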
Example #13
def work(data_name, exp_dir_1, data_dir_1, exp_dir_2, data_dir_2,
         out_detail_fp, out_anno_fp):
    qa_list = load_compq()
    detail_fp_1 = exp_dir_1 + '/detail/full.t.best'
    detail_fp_2 = exp_dir_2 + '/detail/full.t.best'
    qidx_meta_dict_1 = read_ours(detail_fp_1)
    qidx_meta_dict_2 = read_ours(detail_fp_2)
    bw_detail = codecs.open(out_detail_fp, 'w', 'utf-8')
    bw_anno = codecs.open(out_anno_fp, 'w', 'utf-8')
    LogInfo.redirect(bw_detail)
    for bw in (bw_detail, bw_anno):
        bw.write('detail_fp_1: [%s] --> [%s]\n' % (data_dir_1, detail_fp_1))
        bw.write('detail_fp_2: [%s] --> [%s]\n\n' % (data_dir_2, detail_fp_2))

    missing_list = []
    first_only_list = []
    second_only_list = []
    compare_list = []
    if data_name == 'WebQ':
        range_list = range(3778, 5810)
    else:
        assert data_name == 'CompQ'
        range_list = range(1300, 2100)
    for q_idx in range_list:
        if q_idx not in qidx_meta_dict_1 and q_idx not in qidx_meta_dict_2:
            missing_list.append(q_idx)
        elif q_idx not in qidx_meta_dict_2:
            first_only_list.append(q_idx)
        elif q_idx not in qidx_meta_dict_1:
            second_only_list.append(q_idx)
        else:
            compare_list.append(q_idx)

    LogInfo.logs('Missing questions: %s', missing_list)
    LogInfo.logs('First only questions: %s', first_only_list)
    LogInfo.logs('Second only questions: %s\n', second_only_list)

    time_f1_list = [[], []]
    nontime_f1_list = [[], []]
    mark_counter = {}
    disc = Discretizer(split_list=[-0.5, -0.1, -0.000001, 0.000001, 0.1, 0.5])
    compare_list.sort(
        key=lambda x: qidx_meta_dict_1[x]['f1'] - qidx_meta_dict_2[x]['f1'])
    for q_idx in compare_list:
        info_dict_1 = qidx_meta_dict_1[q_idx]
        info_dict_2 = qidx_meta_dict_2[q_idx]
        f1_1 = info_dict_1['f1']
        f1_2 = info_dict_2['f1']
        delta = f1_1 - f1_2
        disc.convert(delta)
        qa = qa_list[q_idx]
        LogInfo.logs('============================\n')
        LogInfo.begin_track('Q-%04d: [%s]', q_idx, qa['utterance'])
        LogInfo.logs('f1_1 = %.6f, f1_2 = %.6f, delta = %.6f', f1_1, f1_2,
                     delta)
        upb_list = []
        for d_idx, (data_dir,
                    info_dict) in enumerate([(data_dir_1, info_dict_1),
                                             (data_dir_2, info_dict_2)]):
            LogInfo.begin_track('Schema-%d, line = %d', d_idx,
                                info_dict['line_no'])
            upb = retrieve_schema(data_dir, q_idx, info_dict['line_no'])
            upb_list.append(upb)
            LogInfo.end_track()
        LogInfo.end_track()
        LogInfo.logs('')

        bw_anno.write('Q-%04d: [%s]\n' % (q_idx, qa['utterance']))
        bw_anno.write('f1_1 = %.6f, f1_2 = %.6f, delta = %.6f\n' %
                      (f1_1, f1_2, delta))
        if abs(delta) >= 0.5:
            hml = 'H'
        elif abs(delta) >= 0.1:
            hml = 'M'
        elif abs(delta) >= 1e-6:
            hml = 'L'
        else:
            hml = '0'
        if delta >= 1e-6:
            sgn = '+'
        elif delta <= -1e-6:
            sgn = '-'
        else:
            sgn = ''
        bw_anno.write('# Change: [%s%s]\n' % (sgn, hml))
        has_time = 'N'
        for tok in qa['tokens']:
            if re.match('^[1-2][0-9][0-9][0-9]$', tok.token[:4]):
                has_time = 'Y'
                break
        if has_time == 'Y':
            time_f1_list[0].append(f1_1)
            time_f1_list[1].append(f1_2)
        else:
            nontime_f1_list[0].append(f1_1)
            nontime_f1_list[1].append(f1_2)
        bw_anno.write('# Time: [%s]\n' % has_time)
        upb1, upb2 = upb_list
        if upb1 - upb2 <= -1e-6:
            upb_mark = 'Less'
        elif upb1 - upb2 >= 1e-6:
            upb_mark = 'Greater'
        else:
            upb_mark = 'Equal'
        bw_anno.write('# Upb: [%s] (%.3f --> %.3f)\n' % (upb_mark, upb1, upb2))
        overall = '%s%s_%s_%s' % (sgn, hml, has_time, upb_mark)
        mark_counter[overall] = 1 + mark_counter.get(overall, 0)
        bw_anno.write('# Overall: [%s]\n' % overall)
        bw_anno.write('\n\n')

    disc.show_distribution()

    LogInfo.logs('')
    for has_time in ('Y', 'N'):
        LogInfo.logs('Related to DateTime: [%s]', has_time)
        LogInfo.logs('    \tLess\tEqual\tGreater')
        for hml in ('-H', '-M', '-L', '0', '+L', '+M', '+H'):
            line = '%4s' % hml
            for upb_mark in ('Less', 'Equal', 'Greater'):
                overall = '%s_%s_%s' % (hml, has_time, upb_mark)
                count = mark_counter.get(overall, 0)
                line += '\t%4d' % count
                # LogInfo.logs('[%s]: %d (%.2f%%)', overall, count, 100. * count / 800)
            LogInfo.logs(line)
        LogInfo.logs('')
    LogInfo.logs('DateTime-related F1: %.6f v.s. %.6f, size = %d',
                 np.mean(time_f1_list[0]), np.mean(time_f1_list[1]),
                 len(time_f1_list[0]))
    LogInfo.logs('DateTime-not-related F1: %.6f v.s. %.6f, size = %d',
                 np.mean(nontime_f1_list[0]), np.mean(nontime_f1_list[1]),
                 len(nontime_f1_list[0]))

    LogInfo.stop_redirect()

    bw_detail.close()
    bw_anno.close()
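The change marker written to bw_anno combines the sign of the F1 delta with a coarse magnitude bucket (H for >= 0.5, M for >= 0.1, L for >= 1e-6, otherwise '0'). A stand-alone sketch of exactly that classification:

# Mirrors the sgn / hml construction above.
def change_mark(delta):
    if abs(delta) >= 0.5:
        hml = 'H'
    elif abs(delta) >= 0.1:
        hml = 'M'
    elif abs(delta) >= 1e-6:
        hml = 'L'
    else:
        hml = '0'
    sgn = '+' if delta >= 1e-6 else '-' if delta <= -1e-6 else ''
    return sgn + hml

assert change_mark(0.35) == '+M' and change_mark(-0.6) == '-H' and change_mark(0.) == '0'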
Example #14
def main(args):
    # ==== Optm & Eval register ==== #
    # ltr: learning-to-rank; full: fully-connected layer as the last layer
    full_optm_method = args.full_optm_method
    if full_optm_method in ('el', 'rm'):        # sub-task only mode
        optm_tasks = eval_tasks = [full_optm_method]
    else:                                       # ltr or full
        optm_tasks = ['el', 'rm', 'full']
        eval_tasks = ['el', 'rm', 'full']
        # all sub-tasks needed, including full
        # we need direct comparison between full and ltr
        if full_optm_method == 'ltr':
            eval_tasks.append('ltr')
    LogInfo.logs('full_optm_method: %s', full_optm_method)
    LogInfo.logs('optimize tasks: %s', optm_tasks)
    LogInfo.logs('evaluate tasks: %s', eval_tasks)

    # ==== Loading Necessary Util ==== #
    LogInfo.begin_track('Loading Utils ... ')
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    LogInfo.end_track()

    # ==== Loading Dataset ==== #
    LogInfo.begin_track('Creating Dataset ... ')
    data_config = literal_eval(args.data_config)
    data_config['el_feat_size'] = 3
    data_config['extra_feat_size'] = 16
    data_config['wd_emb_util'] = wd_emb_util
    data_config['verbose'] = args.verbose
    schema_dataset = SchemaDatasetDep(**data_config)
    schema_dataset.load_all_data()
    """ load data before constructing model, as we generate lookup dict in the loading phase """
    LogInfo.end_track()

    # ==== Building Model ==== #
    LogInfo.begin_track('Building Model and Session ... ')
    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            intra_op_parallelism_threads=8))
    model_config = literal_eval(args.model_config)
    for key in ('qw_max_len', 'pw_max_len', 'path_max_size', 'pseq_max_len',
                'el_feat_size', 'extra_feat_size'):
        model_config[key] = getattr(schema_dataset, key)
    model_config['n_words'] = len(schema_dataset.active_dicts['word'])
    model_config['n_paths'] = len(schema_dataset.active_dicts['path'])
    model_config['n_mids'] = len(schema_dataset.active_dicts['mid'])
    model_config['dim_emb'] = wd_emb_util.dim_emb
    full_back_prop = model_config['full_back_prop']
    compq_mt_model = CompqMultiTaskModel(**model_config)

    LogInfo.begin_track('Showing final parameters: ')
    for var in tf.global_variables():
        LogInfo.logs('%s: %s', var.name, var.get_shape().as_list())
    LogInfo.end_track()

    attempt_param_list = ['rm_task/rm_final_merge/sent_repr/out_fc/weights',
                          'rm_task/rm_final_merge/sent_repr/out_fc/biases',
                          'rm_task/rm_final_merge/alpha',
                          'el_task/out_fc/weights', 'el_task/out_fc/biases',
                          'full_task/out_fc/weights', 'full_task/out_fc/biases']
    focus_param_list = []
    focus_param_name_list = []
    with tf.variable_scope('', reuse=tf.AUTO_REUSE):
        for param_name in attempt_param_list:
            try:
                var = tf.get_variable(name=param_name)
                focus_param_list.append(var)
                focus_param_name_list.append(param_name)
            except ValueError:
                pass
    LogInfo.begin_track('Showing %d concern parameters: ', len(focus_param_list))
    for name, tensor in zip(focus_param_name_list, focus_param_list):
        LogInfo.logs('%s --> %s', name, tensor.get_shape().as_list())
    LogInfo.end_track()

    saver = tf.train.Saver()
    LogInfo.begin_track('Running global_variables_initializer ...')
    start_epoch = 0
    best_valid_f1 = 0.
    resume_flag = False
    model_dir = None
    if args.resume_model_name not in ('', 'None'):
        model_dir = '%s/%s' % (args.output_dir, args.resume_model_name)
        if os.path.exists(model_dir):
            resume_flag = True
    if resume_flag:
        start_epoch, best_valid_f1 = load_model(saver=saver, sess=sess, model_dir=model_dir)
    else:
        dep_simulate = True if args.dep_simulate == 'True' else False
        wd_emb_mat = wd_emb_util.produce_active_word_embedding(
            active_word_dict=schema_dataset.active_dicts['word'],
            dep_simulate=dep_simulate
        )
        pa_emb_mat = np.random.uniform(low=-0.1, high=0.1,
                                       size=(model_config['n_paths'], model_config['dim_emb'])).astype('float32')
        mid_emb_mat = np.random.uniform(low=-0.1, high=0.1,
                                        size=(model_config['n_mids'], model_config['dim_emb'])).astype('float32')
        LogInfo.logs('%s random path embedding created.', pa_emb_mat.shape)
        LogInfo.logs('%s random mid embedding created.', mid_emb_mat.shape)
        sess.run(tf.global_variables_initializer(),
                 feed_dict={compq_mt_model.w_embedding_init: wd_emb_mat,
                            compq_mt_model.p_embedding_init: pa_emb_mat,
                            compq_mt_model.m_embedding_init: mid_emb_mat})
    LogInfo.end_track('Start Epoch = %d', start_epoch)
    LogInfo.end_track('Model build complete.')

    # ==== Register optm / eval ==== #
    rm_optimizer = BaseOptimizer(task_name='rm', compq_mt_model=compq_mt_model, sess=sess)
    el_optimizer = BaseOptimizer(task_name='el', compq_mt_model=compq_mt_model, sess=sess)
    full_optimizer = BaseOptimizer(task_name='full', compq_mt_model=compq_mt_model, sess=sess)
    rm_evaluator = BaseEvaluator(task_name='rm', compq_mt_model=compq_mt_model,
                                 sess=sess, detail_disp_func=show_basic_rm_info)
    el_evaluator = BaseEvaluator(task_name='el', compq_mt_model=compq_mt_model,
                                 sess=sess, detail_disp_func=show_el_detail_without_type)
    full_evaluator = BaseEvaluator(task_name='full', compq_mt_model=compq_mt_model,
                                   sess=sess, detail_disp_func=show_basic_full_info)
    LogInfo.logs('Optimizer & Evaluator defined for RM, EL and FULL.')

    # ==== Iteration begins ==== #
    output_dir = args.output_dir
    if not os.path.exists(output_dir + '/detail'):
        os.makedirs(output_dir + '/detail')
    if not os.path.exists(output_dir + '/result'):
        os.makedirs(output_dir + '/result')

    LogInfo.begin_track('Learning start ...')
    patience = args.max_patience

    status_fp = output_dir + '/status.csv'
    disp_header_list = construct_display_header(optm_tasks=optm_tasks, eval_tasks=eval_tasks)
    with open(status_fp, 'a') as bw:
        write_str = ''.join(disp_header_list)
        bw.write(write_str + '\n')

    if full_back_prop:
        LogInfo.logs('full_back_prop = %s, pre_train_steps = %d.', full_back_prop, args.pre_train_steps)
    else:
        LogInfo.logs('no pre-train available.')

    sent_usage = model_config['sent_usage']
    dep_or_cp = 'dep'
    if sent_usage.startswith('cp'):
        dep_or_cp = 'cp'
    dl_builder = DepSchemaDLBuilder(schema_dataset=schema_dataset, compq_mt_model=compq_mt_model,
                                    neg_pick_config=literal_eval(args.neg_pick_config),
                                    parser_port=parser_port_dict[args.machine],
                                    dep_or_cp=dep_or_cp)
    for epoch in range(start_epoch+1, args.max_epoch+1):
        if patience == 0:
            LogInfo.logs('Early stopping at epoch = %d.', epoch)
            break
        update_flag = False
        disp_item_dict = {'Epoch': epoch}

        LogInfo.begin_track('Epoch %d / %d', epoch, args.max_epoch)

        LogInfo.begin_track('Generating schemas ...')
        task_dls_dict = {}
        for task_name in eval_tasks:
            task_dls_dict[task_name] = dl_builder.build_task_dataloaders(
                task_name=task_name,
                optm_batch_size=args.optm_batch_size,
                eval_batch_size=args.eval_batch_size
            )
            # [task_optm_dl, task_eval_train_dl, ...]
        el_dl_list = task_dls_dict.get('el')
        rm_dl_list = task_dls_dict.get('rm')
        full_dl_list = task_dls_dict.get('full')        # these variables could be None
        LogInfo.end_track()

        if not args.test_only:      # won't perform training when just testing
            """ ==== Sub-task optimizing ==== """
            if epoch <= args.pre_train_steps or not full_back_prop:
                # pre-train stage, or always need train & update
                LogInfo.begin_track('Multi-task optimizing ... ')
                optm_schedule_list = []
                if 'el' in optm_tasks:
                    el_optimizer.reset_optm_info()
                    optm_schedule_list += [('el', x) for x in range(el_dl_list[0].n_batch)]
                    LogInfo.logs('[ el]: n_rows = %d, n_batch = %d.', len(el_dl_list[0]), el_dl_list[0].n_batch)
                if 'rm' in optm_tasks:
                    rm_optimizer.reset_optm_info()
                    optm_schedule_list += [('rm', x) for x in range(rm_dl_list[0].n_batch)]
                    LogInfo.logs('[ rm]: n_rows = %d, n_batch = %d.', len(rm_dl_list[0]), rm_dl_list[0].n_batch)
                np.random.shuffle(optm_schedule_list)
                LogInfo.logs('EL & RM task shuffled.')

                for task_name, batch_idx in optm_schedule_list:
                    if task_name == 'el':
                        el_optimizer.optimize(optm_dl=el_dl_list[0], batch_idx=batch_idx)
                    if task_name == 'rm':
                        rm_optimizer.optimize(optm_dl=rm_dl_list[0], batch_idx=batch_idx)

                if 'el' in optm_tasks:
                    LogInfo.logs('[ el] loss = %.6f', el_optimizer.ret_loss)
                    disp_item_dict['el_loss'] = el_optimizer.ret_loss
                if 'rm' in optm_tasks:
                    LogInfo.logs('[ rm] loss = %.6f', rm_optimizer.ret_loss)
                    disp_item_dict['rm_loss'] = rm_optimizer.ret_loss
                LogInfo.end_track()     # End of optm.

        """ ==== Sub-task evluation, if possible ==== """
        if epoch <= args.pre_train_steps or not full_back_prop:
            for task, task_dl_list, evaluator in [
                ('el', el_dl_list, el_evaluator),
                ('rm', rm_dl_list, rm_evaluator)
            ]:
                if task not in eval_tasks:
                    continue
                LogInfo.begin_track('Evaluation for [%s]:', task)
                for mark, eval_dl in zip('Tvt', task_dl_list[1:]):
                    LogInfo.begin_track('Eval-%s ...', mark)
                    disp_key = '%s_%s_F1' % (task, mark)
                    detail_fp = '%s/detail/%s.%s.tmp' % (output_dir, task, mark)    # detail/rm.T.tmp
                    result_fp = '%s/result/%s.%s.%03d' % (output_dir, task, mark, epoch)    # result/rm.T.001
                    disp_item_dict[disp_key] = evaluator.evaluate_all(
                        eval_dl=eval_dl,
                        detail_fp=detail_fp,
                        result_fp=result_fp
                    )
                    LogInfo.end_track()
                LogInfo.end_track()

        """ ==== full optimization & evaluation, also prepare data for ltr ==== """
        if epoch > args.pre_train_steps or not full_back_prop:
            # pyltr_data_list = []  # save T/v/t <q, [cand]> formation for the use of pyltr
            if 'full' in eval_tasks:
                LogInfo.begin_track('Full-task Optm & Eval:')
                if 'full' in optm_tasks and not args.test_only:
                    LogInfo.begin_track('Optimizing ...')
                    LogInfo.logs('[full]: n_rows = %d, n_batch = %d.', len(full_dl_list[0]), full_dl_list[0].n_batch)
                    full_optimizer.optimize_all(optm_dl=full_dl_list[0])  # quickly optimize the full model
                    LogInfo.logs('[full] loss = %.6f', full_optimizer.ret_loss)
                    disp_item_dict['full_loss'] = full_optimizer.ret_loss
                    LogInfo.end_track()
                for mark, eval_dl in zip('Tvt', full_dl_list[1:]):
                    LogInfo.begin_track('Eval-%s ...', mark)
                    disp_key = 'full_%s_F1' % mark
                    detail_fp = '%s/detail/full.%s.tmp' % (output_dir, mark)
                    result_fp = '%s/result/full.%s.%03d' % (output_dir, mark, epoch)    # result/full.T.001
                    disp_item_dict[disp_key] = full_evaluator.evaluate_all(
                        eval_dl=eval_dl,
                        detail_fp=detail_fp,
                        result_fp=result_fp
                    )
                    # pyltr_data_list.append(full_evaluator.ret_q_score_dict)
                    LogInfo.end_track()
                LogInfo.end_track()

        """ ==== LTR optimization & evaluation (non-TF code) ==== """
        # if 'ltr' in eval_tasks:
        #     LogInfo.begin_track('LTR Optm & Eval:')
        #     assert len(pyltr_data_list) == 3
        #     LogInfo.logs('rich_feats_concat collected for all T/v/t schemas.')
        #     LogInfo.begin_track('Ready for ltr running ... ')
        #     ltr_metric_list = ltr_whole_process(pyltr_data_list=pyltr_data_list,
        #                                         eval_dl_list=full_dl_list[1:],
        #                                         output_dir=output_dir)
        #     LogInfo.end_track()
        #     for mark_idx, mark in enumerate(['T', 'v', 't']):
        #         key = 'ltr_%s_F1' % mark
        #         disp_item_dict[key] = ltr_metric_list[mark_idx]
        #     LogInfo.end_track()

        """ Display & save states (results, details, params) """
        validate_focus = '%s_v_F1' % full_optm_method
        if validate_focus in disp_item_dict:
            cur_valid_f1 = disp_item_dict[validate_focus]
            if cur_valid_f1 > best_valid_f1:
                best_valid_f1 = cur_valid_f1
                update_flag = True
                patience = args.max_patience
            else:
                patience -= 1
            LogInfo.logs('Model %s, best %s = %.6f [patience = %d]',
                         'updated' if update_flag else 'stayed',
                         validate_focus, cur_valid_f1, patience)
            disp_item_dict['Status'] = 'UPDATE' if update_flag else str(patience)
        else:
            disp_item_dict['Status'] = '------'

        disp_item_dict['Time'] = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        with open(status_fp, 'a') as bw:
            write_str = ''
            for item_idx, header in enumerate(disp_header_list):
                if header.endswith(' ') or header == '\t':        # just a split
                    write_str += header
                else:
                    val = disp_item_dict.get(header, '--------')
                    if isinstance(val, float):
                        write_str += '%8.6f' % val
                    else:
                        write_str += str(val)
            bw.write(write_str + '\n')

        LogInfo.logs('Output concern parameters ... ')
        param_result_list = sess.run(focus_param_list)  # don't need any feeds, since we focus on parameters
        with open(output_dir + '/detail/param.%03d' % epoch, 'w') as bw:
            LogInfo.redirect(bw)
            np.set_printoptions(threshold=np.nan)
            for param_name, param_result in zip(focus_param_name_list, param_result_list):
                LogInfo.logs('%s: shape = %s ', param_name, param_result.shape)
                LogInfo.logs(param_result)
                LogInfo.logs('============================\n')
            np.set_printoptions()
            LogInfo.stop_redirect()

        if update_flag:     # save the latest details
            for mode in 'Tvt':
                for task in ('rm', 'el', 'full', 'ltr'):
                    src = '%s/detail/%s.%s.tmp' % (output_dir, task, mode)
                    dest = '%s/detail/%s.%s.best' % (output_dir, task, mode)
                    if os.path.isfile(src):
                        shutil.move(src, dest)
            if args.save_best:
                save_best_dir = '%s/model_best' % output_dir
                delete_dir(save_best_dir)
                save_model(saver=saver, sess=sess, model_dir=save_best_dir,
                           epoch=epoch, valid_metric=best_valid_f1)

        LogInfo.end_track()     # end of epoch
    LogInfo.end_track()         # end of learning
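The epoch loop above follows patience-based early stopping: improving on the best validation F1 resets the patience counter, any other epoch decrements it, and training halts once it reaches zero. A minimal sketch with hypothetical validation scores:

# Minimal sketch of the patience logic driving the epoch loop; scores are made up.
def train_with_patience(valid_scores, max_patience):
    best, patience, stopped_at = 0., max_patience, None
    for epoch, score in enumerate(valid_scores, start=1):
        if patience == 0:          # checked at the top of the epoch, as above
            stopped_at = epoch
            break
        if score > best:
            best, patience = score, max_patience   # model "updated"
        else:
            patience -= 1                          # model "stayed"
    return best, stopped_at

# With max_patience = 2, training halts after two consecutive non-improving epochs.
assert train_with_patience([0.40, 0.42, 0.41, 0.41, 0.39], max_patience=2) == (0.42, 5)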