示例#1
0
def load_dataset(mode,
                 vocab_id_mapping,
                 max_seq_len,
                 sampling=False,
                 with_label=True,
                 label_version=None):
    """
    Assemble the dataset dict for one split.

    The tokenized text of the split is converted to token-id lists, gold
    labels are attached on request, the dataset is optionally passed through
    custom_sampling, and finally the token ids are converted by to_nn_input.

    :param mode: split key handed to data_config.path
    :param vocab_id_mapping: mapping forwarded to tokenized_to_tid_list
    :param max_seq_len: forwarded to to_nn_input as max_seq_len
    :param sampling: when true, replace the dataset by custom_sampling(dataset)
    :param with_label: when true, load gold labels and also return output_dim
    :param label_version: label version passed to data_config.path
    :return: (dataset, output_dim) when with_label is true, else dataset only;
             output_dim is the largest gold label plus one
    """
    tokenized = load_tokenized_list(data_config.path(mode, TEXT, EK))
    token_ids = tokenized_to_tid_list(tokenized, vocab_id_mapping)
    dataset = {TID: token_ids}

    # Report the length of the longest token-id sequence in this split.
    longest = max(len(_seq) for _seq in token_ids)
    print('{}: {}'.format(mode, longest))

    if with_label:
        gold = load_label_list(data_config.path(mode, LABEL, label_version))
        dataset[LABEL_GOLD] = np.asarray(gold)

    if sampling:
        dataset = custom_sampling(dataset)

    nn_tid, nn_seq_len = to_nn_input(dataset[TID], max_seq_len=max_seq_len)
    dataset[TID] = nn_tid
    dataset[SEQ_LEN] = nn_seq_len

    if not with_label:
        return dataset
    return dataset, max(dataset[LABEL_GOLD]) + 1
def build_text_label():
    """
    Export the text file and the A/B label files for the TRAIN and TEST splits.

    For each split the origin loader is read with label version 'B'. Every
    text is written with each whitespace run collapsed to a single space, the
    B label is written unchanged, and the A label is derived as 0 when the B
    label is 0 and 1 otherwise. The derived A labels are then compared with
    the labels yielded by the loader for version 'A', and the number of
    mismatches is printed next to the split key.

    :return: None
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence.
    whitespace = re.compile(r'\s+')
    loaders = (
        (TRAIN, Processor.load_origin_train),
        (TEST, Processor.load_origin_test),
    )

    for key, func in loaders:
        text_path = config.path(key, TEXT)
        label_A_path = config.path(key, LABEL, 'A')
        label_B_path = config.path(key, LABEL, 'B')

        labels_A = list()
        with open(text_path, 'w') as text_obj, \
                open(label_A_path, 'w') as label_A_obj, \
                open(label_B_path, 'w') as label_B_obj:
            for label, text in func('B'):
                label_A = 0 if label == 0 else 1
                text_obj.write(whitespace.sub(' ', text) + '\n')
                label_B_obj.write(str(label) + '\n')
                label_A_obj.write(str(label_A) + '\n')
                labels_A.append(label_A)

        # Sanity check: derived A labels versus the loader's own A labels.
        mismatch = 0
        for i, res in enumerate(func('A')):
            if res[0] != labels_A[i]:
                mismatch += 1
        print(key, mismatch)
示例#3
0
def m3(config_path='e83.yaml'):
    """
    Evaluate a staged ensemble vote on the TEST split and save the final labels.

    [Usage]
    python3 -m algo.ensemble93 main -e mv --build-analysis
    NOTE(review): the usage line is kept from the original docstring and names
    the `main` sub-command; confirm the right invocation for this function.

    The base vote is the first field of every combined result of the default
    components. Then, for i in 1, 2, 3, the components registered under
    'b0<i>' are combined; wherever the current vote is 0 or i and the second
    field of the b0<i> result reaches config.thr('b0<i>'), the vote is replaced
    by 0 (when the b0<i> label is 0) or by i. Every stage is evaluated against
    the 'B' gold labels and printed with its confusion matrix. The last vote
    is written to 'latest_ef83.label', one label per line.

    :param config_path: path of the YAML ensemble configuration
    :return: None
    """
    # safe_load builds plain Python types only and exists in every PyYAML
    # release; a bare yaml.load(stream) without Loader fails on PyYAML >= 6.
    with open(config_path) as config_file:
        config_data = yaml.safe_load(config_file)
    config = Config(data=config_data)

    for mode in [TEST, ]:
        labels_gold = load_label_list(data_config.path(mode, LABEL, 'B'))

        b_result = combine(output_keys=config.components(), mode=mode)
        last_vote = [_item[0] for _item in b_result]

        res = basic_evaluate(gold=labels_gold, pred=last_vote)

        print('{}'.format(mode))
        print_evaluation(res)
        for col in res[CONFUSION_MATRIX]:
            print(','.join(map(str, col)))

        for i in [1, 2, 3]:
            key = 'b0{}'.format(i)
            thr = config.thr(key)
            b0_result = combine(output_keys=config.components(key), mode=mode)

            new_vote = list()
            for l_v, b0_res in zip(last_vote, b0_result):
                # Only votes that are currently 0 or i may be overridden.
                if l_v in {0, i} and b0_res[1] >= thr:
                    new_vote.append(0 if b0_res[0] == 0 else i)
                else:
                    new_vote.append(l_v)
            last_vote = new_vote

            res = basic_evaluate(gold=labels_gold, pred=new_vote)

            print('{} - {}'.format(mode, i))
            print_evaluation(res)
            for col in res[CONFUSION_MATRIX]:
                print(','.join(map(str, col)))

        with open('latest_ef83.label', 'w') as label_file:
            label_file.write('\n'.join(map(str, last_vote)))
示例#4
0
def main(config_path='e83.yaml'):
    """
    Evaluate a three-stage cascade of ensemble votes on TRAIN and TEST.

    [Usage]
    python3 -m algo.ensemble93 main -e mv --build-analysis

    The component groups 'b', 'b2' and 'b3' are combined separately and the
    first field of each combined result is taken as that group's vote. The
    final label of a sample is:
        0 if the 'b' vote is 0,
        1 else if the 'b2' vote is 0,
        2 else if the 'b3' vote is 0,
        3 otherwise.
    The prediction is evaluated against the 'B' gold labels and printed
    together with its confusion matrix.

    :param config_path: path of the YAML ensemble configuration
    :return: None
    """
    # safe_load builds plain Python types only and exists in every PyYAML
    # release; a bare yaml.load(stream) without Loader fails on PyYAML >= 6.
    with open(config_path) as config_file:
        config_data = yaml.safe_load(config_file)
    config = Config(data=config_data)

    for mode in [TRAIN, TEST]:
        b_result = combine(output_keys=config.components('b'), mode=mode)
        b_vote = [_item[0] for _item in b_result]

        b2_result = combine(output_keys=config.components('b2'), mode=mode)
        b2_vote = [_item[0] for _item in b2_result]

        b3_result = combine(output_keys=config.components('b3'), mode=mode)
        b3_vote = [_item[0] for _item in b3_result]

        # One pass over the cascade: the first stage voting 0 fixes the label.
        labels_predict = list()
        for b_v, b2_v, b3_v in zip(b_vote, b2_vote, b3_vote):
            if b_v == 0:
                label = 0
            elif b2_v == 0:
                label = 1
            elif b3_v == 0:
                label = 2
            else:
                label = 3
            labels_predict.append(label)

        labels_gold = load_label_list(data_config.path(mode, LABEL, 'B'))

        res = basic_evaluate(gold=labels_gold, pred=labels_predict)

        print(mode)
        print_evaluation(res)
        for col in res[CONFUSION_MATRIX]:
            print(','.join(map(str, col)))
示例#5
0
def m3a(target=0, thr=1, config_path='e83a.yaml'):
    """
    Evaluate a thresholded binary ensemble vote on the TEST split.

    Stage 1: a sample is labelled `target` when the first field of its
    combined result equals `target` and the second field is at least `thr`;
    otherwise it is labelled `1 - target`.
    Stage 2: the 'b' components are combined with full_output=True; wherever
    the first entry of a sample's count is at most 1 the label is forced to 0,
    otherwise the stage-1 label is kept.
    Both stages are evaluated against the 'A' gold labels and printed with
    their confusion matrices.

    :param target: label to vote for; converted with int()
    :param thr: minimum value of the result's second field; converted with int()
    :param config_path: path of the YAML ensemble configuration
    :return: None
    """
    target = int(target)
    thr = int(thr)
    # safe_load builds plain Python types only and exists in every PyYAML
    # release; a bare yaml.load(stream) without Loader fails on PyYAML >= 6.
    with open(config_path) as config_file:
        config_data = yaml.safe_load(config_file)
    config = Config(data=config_data)

    for mode in [TEST, ]:
        labels_gold = load_label_list(data_config.path(mode, LABEL, 'A'))

        b_result = combine(output_keys=config.components(), mode=mode)
        new_vote = list()
        for r in b_result:
            if r[0] == target and r[1] >= thr:
                new_vote.append(target)
            else:
                new_vote.append(1 - target)
        res = basic_evaluate(gold=labels_gold, pred=new_vote)

        print('{}'.format(mode))
        print_evaluation(res)
        for col in res[CONFUSION_MATRIX]:
            print(','.join(map(str, col)))

        last_vote = new_vote
        output_keys = config.components('b')
        # Only the per-sample counts of the full output are needed here.
        _, counts = combine(output_keys=output_keys,
                            mode=mode,
                            full_output=True)
        new_vote = list()
        for count, l_v in zip(counts, last_vote):
            if count[0] <= 1:
                new_vote.append(0)
            else:
                new_vote.append(l_v)
        res = basic_evaluate(gold=labels_gold, pred=new_vote)
        print('{}'.format(mode))
        print_evaluation(res)
        for col in res[CONFUSION_MATRIX]:
            print(','.join(map(str, col)))
示例#6
0
def train(text_version='ek', label_version=None, config_path='c83.yaml'):
    """
    Train NNModel, checkpoint the best epoch and export that checkpoint's outputs.

    [Usage]
    python -m algo.main93_v2 train
    python3 -m algo.main93_v2 train -c config_ntua93.yaml

    Steps performed:
    1. Load the YAML config, derive a timestamped output key, create the
       output/model folders and copy the config file next to the outputs.
    2. Load the vocabulary (terms with tf >= word.min_tf), the lookup table
       and the vocab-id mapping; dump the mapping as JSON.
    3. Load the TRAIN (optionally sampled) and TEST datasets and split TRAIN
       into train/valid according to train.valid_rate.
    4. For every epoch: train, evaluate on the validation part (when
       valid_rate != 0) and on TEST. A checkpoint is saved whenever the
       early-stop metric improves (on TRAIN when valid_rate == 0, otherwise
       on VALID). Training stops once the TRAIN metric has not improved for
       10 consecutive epochs.
    5. Restore the latest checkpoint and export hidden features, predicted
       labels and predicted probabilities for TRAIN and TEST, then print and
       dump the best evaluations.

    :param text_version: string, text/vocab version used in data paths
    :param label_version: string or None; 'A' makes label 1 the positive label
    :param config_path: string, path of the YAML configuration file
    :return: None
    """
    pos_label = 1 if label_version == 'A' else None

    # safe_load builds plain Python types only and exists in every PyYAML
    # release; a bare yaml.load(stream) without Loader fails on PyYAML >= 6.
    with open(config_path) as config_file:
        config_data = yaml.safe_load(config_file)

    output_key = '{}_{}_{}'.format(NNModel.name, text_version,
                                   int(time.time()))
    if label_version is not None:
        output_key = '{}_{}'.format(label_version, output_key)
    print('OUTPUT_KEY: {}'.format(output_key))

    # Prepare the folders for outputs and model checkpoints.
    data_config.prepare_output_folder(output_key=output_key)
    data_config.prepare_model_folder(output_key=output_key)

    shutil.copy(config_path, data_config.output_path(output_key, ALL, CONFIG))

    w2v_key = '{}_{}'.format(config_data['word']['w2v_version'], text_version)
    w2v_model_path = data_config.path(ALL, WORD2VEC, w2v_key)
    vocab_train_path = data_config.path(TRAIN, VOCAB, text_version)

    # Load the vocabulary.
    # The model uses every word vector available in the pretrained model and
    # randomly initialises vectors for words that occur often enough.
    vocab_meta_list = load_vocab_list(vocab_train_path)
    vocabs = [
        _meta['t'] for _meta in vocab_meta_list
        if _meta['tf'] >= config_data['word']['min_tf']
    ]

    # Load the word vectors and the related data.
    lookup_table, vocab_id_mapping, embedding_dim = load_lookup_table2(
        w2v_model_path=w2v_model_path, vocabs=vocabs)
    with open(data_config.output_path(output_key, ALL, VOCAB_ID_MAPPING),
              'w') as file_obj:
        json.dump(vocab_id_mapping, file_obj)

    # Load the configuration objects.
    nn_config = NNConfig(config_data)
    train_config = TrainConfig(config_data['train'])
    early_stop_metric = train_config.early_stop_metric

    # Load the training and test data.
    datasets = dict()
    datasets[TRAIN], output_dim = load_dataset(
        mode=TRAIN,
        vocab_id_mapping=vocab_id_mapping,
        max_seq_len=nn_config.seq_len,
        sampling=train_config.train_sampling,
        label_version=label_version)
    datasets[TEST], _ = load_dataset(mode=TEST,
                                     vocab_id_mapping=vocab_id_mapping,
                                     max_seq_len=nn_config.seq_len,
                                     label_version=label_version)

    # Initialise the index iterator of the training set.
    index_iterators = {
        TRAIN: IndexIterator.from_dataset(datasets[TRAIN]),
    }
    # Split the training data into train and validation parts per config.
    index_iterators[TRAIN].split_train_valid(train_config.valid_rate)

    # Compute the weight of every class.
    if train_config.use_class_weights:
        label_weight = {
            # Follows the formula behind sklearn's class_weight='balanced';
            # it helped noticeably in experiments.
            _label: float(index_iterators[TRAIN].n_sample()) /
            (index_iterators[TRAIN].dim * len(_index))
            for _label, _index in index_iterators[TRAIN].label_index.items()
        }
    else:
        label_weight = {
            _label: 1.
            for _label in range(index_iterators[TRAIN].dim)
        }

    # Update the network config with values derived from the loaded data.
    nn_config.set_embedding_dim(embedding_dim)
    nn_config.set_output_dim(output_dim)
    # Build the neural network.
    nn = NNModel(config=nn_config)
    nn.build_neural_network(lookup_table=lookup_table)

    batch_size = train_config.batch_size
    fetches = {
        mode: {_key: nn.var(_key)
               for _key in fetch_key[mode]}
        for mode in [TRAIN, TEST]
    }

    model_output_prefix = data_config.model_path(key=output_key) + '/model'

    best_res = {mode: None for mode in [TRAIN, VALID]}
    no_update_count = {mode: 0 for mode in [TRAIN, VALID]}
    max_no_update_count = 10

    eval_history = {TRAIN: list(), VALID: list(), TEST: list()}

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())

        dataset = datasets[TRAIN]
        index_iterator = index_iterators[TRAIN]

        # Training starts ###################################################
        for epoch in range(train_config.epoch):
            print('== epoch {} = {} ='.format(epoch, output_key))

            # Train on the training part.
            print('TRAIN')
            n_sample = index_iterator.n_sample(TRAIN)
            labels_predict = list()
            labels_gold = list()

            for batch_index in index_iterator.iterate(batch_size,
                                                      mode=TRAIN,
                                                      shuffle=True):
                feed_dict = {
                    nn.var(_key): dataset[_key][batch_index]
                    for _key in feed_key[TRAIN]
                }
                feed_dict[nn.var(SAMPLE_WEIGHTS)] = list(
                    map(label_weight.get, feed_dict[nn.var(LABEL_GOLD)]))
                feed_dict[nn.var(TEST_MODE)] = 0
                res = sess.run(fetches=fetches[TRAIN], feed_dict=feed_dict)

                labels_predict += res[LABEL_PREDICT].tolist()
                labels_gold += dataset[LABEL_GOLD][batch_index].tolist()

            # Keep only the first n_sample collected entries.
            labels_predict = labels_predict[:n_sample]
            labels_gold = labels_gold[:n_sample]
            res = basic_evaluate(gold=labels_gold,
                                 pred=labels_predict,
                                 pos_label=pos_label)
            print_evaluation(res)
            eval_history[TRAIN].append(res)

            global_step = tf.train.global_step(sess, nn.var(GLOBAL_STEP))

            if train_config.valid_rate == 0.:
                # No validation part: checkpoint on TRAIN improvements.
                if best_res[TRAIN] is None or res[
                        early_stop_metric] > best_res[TRAIN][early_stop_metric]:
                    best_res[TRAIN] = res
                    no_update_count[TRAIN] = 0
                    saver.save(sess,
                               save_path=model_output_prefix,
                               global_step=global_step)
                else:
                    no_update_count[TRAIN] += 1
            else:
                if best_res[TRAIN] is None or res[
                        early_stop_metric] > best_res[TRAIN][early_stop_metric]:
                    best_res[TRAIN] = res
                    no_update_count[TRAIN] = 0
                else:
                    no_update_count[TRAIN] += 1

                # Evaluate on the validation part; only the TEST fetches are run.
                print('VALID')
                n_sample = index_iterator.n_sample(VALID)
                labels_predict = list()
                labels_gold = list()

                for batch_index in index_iterator.iterate(batch_size,
                                                          mode=VALID,
                                                          shuffle=False):
                    feed_dict = {
                        nn.var(_key): dataset[_key][batch_index]
                        for _key in feed_key[TEST]
                    }
                    feed_dict[nn.var(TEST_MODE)] = 1
                    res = sess.run(fetches=fetches[TEST], feed_dict=feed_dict)

                    labels_predict += res[LABEL_PREDICT].tolist()
                    labels_gold += dataset[LABEL_GOLD][batch_index].tolist()

                labels_predict = labels_predict[:n_sample]
                labels_gold = labels_gold[:n_sample]
                res = basic_evaluate(gold=labels_gold,
                                     pred=labels_predict,
                                     pos_label=pos_label)
                eval_history[VALID].append(res)
                print_evaluation(res)

                # Checkpoint whenever the validation metric improves.
                if best_res[VALID] is None or res[
                        early_stop_metric] > best_res[VALID][early_stop_metric]:
                    saver.save(sess,
                               save_path=model_output_prefix,
                               global_step=global_step)
                    best_res[VALID] = res
                    no_update_count[VALID] = 0
                else:
                    no_update_count[VALID] += 1

            # Evaluate on the test set after every epoch.
            _mode = TEST
            _dataset = datasets[_mode]
            _index_iterator = SimpleIndexIterator.from_dataset(_dataset)
            _n_sample = _index_iterator.n_sample()

            labels_predict = list()
            labels_gold = list()
            for batch_index in _index_iterator.iterate(batch_size,
                                                       shuffle=False):
                feed_dict = {
                    nn.var(_key): _dataset[_key][batch_index]
                    for _key in feed_key[TEST]
                }
                feed_dict[nn.var(TEST_MODE)] = 1
                res = sess.run(fetches=fetches[TEST], feed_dict=feed_dict)

                labels_predict += res[LABEL_PREDICT].tolist()
                labels_gold += _dataset[LABEL_GOLD][batch_index].tolist()
            labels_predict = labels_predict[:_n_sample]
            labels_gold = labels_gold[:_n_sample]
            res = basic_evaluate(gold=labels_gold,
                                 pred=labels_predict,
                                 pos_label=pos_label)
            eval_history[TEST].append(res)
            print('TEST')
            print_evaluation(res)

            # NOTE(review): only the TRAIN counter ends the loop;
            # no_update_count[VALID] is tracked but never checked. Confirm
            # that this is the intended early-stop rule.
            if no_update_count[TRAIN] >= max_no_update_count:
                break

        # Training finished #################################################

    print(
        '========================= BEST ROUND EVALUATION ========================='
    )

    with open(data_config.output_path(output_key, 'eval', 'json'),
              'w') as file_obj:
        json.dump(eval_history, file_obj)

    with tf.Session() as sess:
        prefix_checkpoint = tf.train.latest_checkpoint(
            data_config.model_path(key=output_key))
        saver = tf.train.import_meta_graph('{}.meta'.format(prefix_checkpoint))
        saver.restore(sess, prefix_checkpoint)

        nn = BaseNNModel(config=None)
        nn.set_graph(tf.get_default_graph())

        for mode in [TRAIN, TEST]:
            if mode == TRAIN and train_config.train_sampling:
                # Export on the unsampled training data.
                dataset, _ = load_dataset(mode=TRAIN,
                                          vocab_id_mapping=vocab_id_mapping,
                                          max_seq_len=nn_config.seq_len,
                                          sampling=False,
                                          label_version=label_version)
            else:
                dataset = datasets[mode]
            index_iterator = SimpleIndexIterator.from_dataset(dataset)
            n_sample = index_iterator.n_sample()

            prob_predict = list()
            labels_predict = list()
            labels_gold = list()
            hidden_feats = list()

            for batch_index in index_iterator.iterate(batch_size,
                                                      shuffle=False):
                feed_dict = {
                    nn.var(_key): dataset[_key][batch_index]
                    for _key in feed_key[TEST]
                }
                feed_dict[nn.var(TEST_MODE)] = 1
                res = sess.run(fetches=fetches[TEST], feed_dict=feed_dict)
                prob_predict += res[PROB_PREDICT].tolist()
                labels_predict += res[LABEL_PREDICT].tolist()
                hidden_feats += res[HIDDEN_FEAT].tolist()
                if LABEL_GOLD in dataset:
                    labels_gold += dataset[LABEL_GOLD][batch_index].tolist()

            prob_predict = prob_predict[:n_sample]
            labels_predict = labels_predict[:n_sample]
            labels_gold = labels_gold[:n_sample]
            hidden_feats = hidden_feats[:n_sample]

            if mode == TEST:
                res = basic_evaluate(gold=labels_gold,
                                     pred=labels_predict,
                                     pos_label=pos_label)
                best_res[TEST] = res

            # Export the hidden-layer features.
            with open(data_config.output_path(output_key, mode, HIDDEN_FEAT),
                      'w') as file_obj:
                for _feat in hidden_feats:
                    file_obj.write('\t'.join(map(str, _feat)) + '\n')
            # Export the predicted labels.
            with open(data_config.output_path(output_key, mode, LABEL_PREDICT),
                      'w') as file_obj:
                for _label in labels_predict:
                    file_obj.write('{}\n'.format(_label))
            # Export the predicted probabilities.
            with open(data_config.output_path(output_key, mode, PROB_PREDICT),
                      'w') as file_obj:
                for _prob in prob_predict:
                    file_obj.write('\t'.join(map(str, _prob)) + '\n')

    for mode in [TRAIN, VALID, TEST]:
        if mode == VALID and train_config.valid_rate == 0.:
            continue
        res = best_res[mode]
        print(mode)
        print_evaluation(res)
        for col in res[CONFUSION_MATRIX]:
            print(','.join(map(str, col)))

        with open(data_config.output_path(output_key, mode, EVALUATION),
                  'w') as file_obj:
            json.dump(res, file_obj)
        print()

    test_score_list = [_item['f1'] for _item in eval_history[TEST]]
    print('best test f1 reached: {}'.format(max(test_score_list)))

    print('OUTPUT_KEY: {}'.format(output_key))
示例#7
0
def train(dataset_key,
          text_version,
          label_version=None,
          config_path='config.yaml'):
    """
    Train the configured NNModel on a dataset and export the best checkpoint's outputs.

    [Usage]
    python algo/main.py train semeval2018_task3 -l A -t ek
    python algo/main.py train semeval2018_task3 -l A -t ek -c config_ntua.yaml
    python algo/main.py train semeval2018_task3 -l A -t raw -c config_ntua_char.yaml

    python algo/main.py train semeval2019_task3_dev -t ek

    python algo/main.py train semeval2018_task1 -l love
    python algo/main.py train semeval2014_task9

    Steps performed:
    1. Load the YAML config, import `dataset.<dataset_key>.config` and the
       model module named by config['module'], derive a timestamped output
       key, create the output/model folders and copy the config file.
    2. Build the lookup table: from word2vec plus vocabulary for the WORD
       analyzer, or a random table over the training characters for CHAR.
    3. Load the datasets, split TRAIN into train/valid by train.valid_rate
       and compute the class weights.
    4. For every epoch: train and, when valid_rate != 0, evaluate on the
       validation part. A checkpoint is saved whenever F1_SCORE improves (on
       TRAIN when valid_rate == 0, otherwise on VALID). Training stops once
       the TRAIN F1 has not improved for 10 consecutive epochs.
    5. Restore the latest checkpoint, export hidden features, predicted
       labels and probabilities for TRAIN and TEST, and dump the evaluations.

    :param dataset_key: string, name of the dataset package under `dataset`
    :param text_version: string, text/vocab version used in data paths
    :param label_version: string or None, label version used in data paths
    :param config_path: string, path of the YAML configuration file
    :return: None
    :raises ValueError: when config['analyzer'] is neither WORD nor CHAR
    """
    pos_label = None
    if dataset_key == 'semeval2018_task3' and label_version == 'A':
        pos_label = 1

    # safe_load builds plain Python types only and exists in every PyYAML
    # release; a bare yaml.load(stream) without Loader fails on PyYAML >= 6.
    with open(config_path) as config_file:
        config_data = yaml.safe_load(config_file)

    data_config = getattr(
        importlib.import_module('dataset.{}.config'.format(dataset_key)),
        'config')

    output_key = '{}_{}_{}'.format(config_data['module'].rsplit('.', 1)[1],
                                   text_version, int(time.time()))
    if label_version is not None:
        output_key = '{}_{}'.format(label_version, output_key)
    print('OUTPUT_KEY: {}'.format(output_key))

    # Prepare the folders for outputs and model checkpoints.
    data_config.prepare_output_folder(output_key=output_key)
    data_config.prepare_model_folder(output_key=output_key)

    shutil.copy(config_path, data_config.output_path(output_key, ALL, CONFIG))

    # Load the model classes from the module named in the configuration.
    module_relative_path = config_data['module']
    nn_module = importlib.import_module(module_relative_path)
    NNModel = getattr(nn_module, 'NNModel')
    NNConfig = getattr(nn_module, 'NNConfig')

    if config_data['analyzer'] == WORD:
        w2v_key = '{}_{}'.format(config_data['word']['w2v_version'],
                                 text_version)
        w2v_model_path = data_config.path(ALL, WORD2VEC, w2v_key)
        vocab_train_path = data_config.path(TRAIN, VOCAB, text_version)

        # Load the vocabulary.
        # The model uses every word vector available in the pretrained model
        # and randomly initialises vectors for words that occur often enough.
        vocab_meta_list = load_vocab_list(vocab_train_path)
        vocab_meta_list += load_vocab_list(
            semeval2018_task3_date_config.path(TRAIN, VOCAB, text_version))
        vocabs = [
            _meta['t'] for _meta in vocab_meta_list
            if _meta['tf'] >= config_data[WORD]['min_tf']
        ]

        # Load the word vectors and the related data.
        lookup_table, vocab_id_mapping, embedding_dim = load_lookup_table(
            w2v_model_path=w2v_model_path, vocabs=vocabs)
        with open(data_config.output_path(output_key, ALL, VOCAB_ID_MAPPING),
                  'w') as file_obj:
            json.dump(vocab_id_mapping, file_obj)
        max_seq_len = MAX_WORD_SEQ_LEN
    elif config_data['analyzer'] == CHAR:
        # Character vocabulary: every character seen in the training texts.
        texts = load_text_list(data_config.path(TRAIN, TEXT))
        char_set = set()
        for text in texts:
            char_set |= set(text)
        lookup_table, vocab_id_mapping, embedding_dim = build_random_lookup_table(
            vocabs=char_set, dim=config_data['char']['embedding_dim'])
        max_seq_len = MAX_CHAR_SEQ_LEN
    else:
        raise ValueError('invalid analyzer: {}'.format(
            config_data['analyzer']))

    # Load the training data.
    datasets, output_dim = load_dataset(data_config=data_config,
                                        analyzer=config_data['analyzer'],
                                        vocab_id_mapping=vocab_id_mapping,
                                        seq_len=max_seq_len,
                                        with_label=True,
                                        label_version=label_version,
                                        text_version=text_version)

    # Load the configuration objects.
    nn_config = NNConfig(config_data)
    train_config = TrainConfig(config_data['train'])

    # Initialise the index iterators of the datasets.
    index_iterators = {
        mode: IndexIterator(datasets[mode][LABEL_GOLD])
        for mode in [TRAIN, TEST]
    }
    # Split the training data into train and validation parts per config.
    index_iterators[TRAIN].split_train_valid(train_config.valid_rate)

    # Compute the weight of every class.
    if train_config.use_class_weights:
        label_weight = {
            # Follows the formula behind sklearn's class_weight='balanced';
            # it helped noticeably in experiments.
            _label: float(index_iterators[TRAIN].n_sample()) /
            (index_iterators[TRAIN].dim * len(_index))
            for _label, _index in index_iterators[TRAIN].label_index.items()
        }
    else:
        label_weight = {
            _label: 1.
            for _label in range(index_iterators[TRAIN].dim)
        }

    # Update the network config with values derived from the loaded data.
    nn_config.set_embedding_dim(embedding_dim)
    nn_config.set_output_dim(output_dim)
    nn_config.set_seq_len(max_seq_len)
    # Build the neural network.
    nn = NNModel(config=nn_config)
    nn.build_neural_network(lookup_table=lookup_table)

    batch_size = train_config.batch_size
    fetches = {
        mode: {_key: nn.var(_key)
               for _key in fetch_key[mode]}
        for mode in [TRAIN, TEST]
    }

    model_output_prefix = data_config.model_path(key=output_key) + '/model'

    best_res = {mode: None for mode in [TRAIN, VALID]}
    no_update_count = {mode: 0 for mode in [TRAIN, VALID]}
    max_no_update_count = 10

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())

        dataset = datasets[TRAIN]
        index_iterator = index_iterators[TRAIN]

        # Training starts ###################################################
        for epoch in range(train_config.epoch):
            print('== epoch {} =='.format(epoch))

            # Train on the training part.
            print('TRAIN')
            n_sample = index_iterator.n_sample(TRAIN)
            labels_predict = list()
            labels_gold = list()

            for batch_index in index_iterator.iterate(batch_size,
                                                      mode=TRAIN,
                                                      shuffle=True):
                feed_dict = {
                    nn.var(_key): dataset[_key][batch_index]
                    for _key in feed_key[TRAIN]
                }
                feed_dict[nn.var(SAMPLE_WEIGHTS)] = list(
                    map(label_weight.get, feed_dict[nn.var(LABEL_GOLD)]))
                feed_dict[nn.var(TEST_MODE)] = 0
                res = sess.run(fetches=fetches[TRAIN], feed_dict=feed_dict)

                labels_predict += res[LABEL_PREDICT].tolist()
                labels_gold += dataset[LABEL_GOLD][batch_index].tolist()

            # Keep only the first n_sample collected entries.
            labels_predict = labels_predict[:n_sample]
            labels_gold = labels_gold[:n_sample]
            res = basic_evaluate(gold=labels_gold,
                                 pred=labels_predict,
                                 pos_label=pos_label)
            print_evaluation(res)

            global_step = tf.train.global_step(sess, nn.var(GLOBAL_STEP))

            if train_config.valid_rate == 0.:
                # No validation part: checkpoint on TRAIN improvements.
                if best_res[TRAIN] is None or res[F1_SCORE] > best_res[TRAIN][
                        F1_SCORE]:
                    best_res[TRAIN] = res
                    no_update_count[TRAIN] = 0
                    saver.save(sess,
                               save_path=model_output_prefix,
                               global_step=global_step)
                else:
                    no_update_count[TRAIN] += 1
            else:
                if best_res[TRAIN] is None or res[F1_SCORE] > best_res[TRAIN][
                        F1_SCORE]:
                    best_res[TRAIN] = res
                    no_update_count[TRAIN] = 0
                else:
                    no_update_count[TRAIN] += 1

                # Evaluate on the validation part; only the TEST fetches are run.
                print('VALID')
                n_sample = index_iterator.n_sample(VALID)
                labels_predict = list()
                labels_gold = list()

                for batch_index in index_iterator.iterate(batch_size,
                                                          mode=VALID,
                                                          shuffle=False):
                    feed_dict = {
                        nn.var(_key): dataset[_key][batch_index]
                        for _key in feed_key[TEST]
                    }
                    feed_dict[nn.var(TEST_MODE)] = 1
                    res = sess.run(fetches=fetches[TEST], feed_dict=feed_dict)
                    labels_predict += res[LABEL_PREDICT].tolist()
                    labels_gold += dataset[LABEL_GOLD][batch_index].tolist()

                labels_predict = labels_predict[:n_sample]
                labels_gold = labels_gold[:n_sample]
                res = basic_evaluate(gold=labels_gold,
                                     pred=labels_predict,
                                     pos_label=pos_label)
                print_evaluation(res)

                # Checkpoint whenever the validation F1 improves.
                if best_res[VALID] is None or res[F1_SCORE] > best_res[VALID][
                        F1_SCORE]:
                    saver.save(sess,
                               save_path=model_output_prefix,
                               global_step=global_step)
                    best_res[VALID] = res
                    no_update_count[VALID] = 0
                else:
                    no_update_count[VALID] += 1

            # NOTE(review): only the TRAIN counter ends the loop;
            # no_update_count[VALID] is tracked but never checked. Confirm
            # that this is the intended early-stop rule.
            if no_update_count[TRAIN] >= max_no_update_count:
                break

        # Training finished #################################################

    print(
        '========================= BEST ROUND EVALUATION ========================='
    )

    with tf.Session() as sess:
        prefix_checkpoint = tf.train.latest_checkpoint(
            data_config.model_path(key=output_key))
        saver = tf.train.import_meta_graph('{}.meta'.format(prefix_checkpoint))
        saver.restore(sess, prefix_checkpoint)

        nn = BaseNNModel(config=None)
        nn.set_graph(tf.get_default_graph())

        for mode in [TRAIN, TEST]:
            dataset = datasets[mode]
            index_iterator = index_iterators[mode]
            n_sample = index_iterator.n_sample()

            prob_predict = list()
            labels_predict = list()
            labels_gold = list()
            hidden_feats = list()

            for batch_index in index_iterator.iterate(batch_size,
                                                      shuffle=False):
                feed_dict = {
                    nn.var(_key): dataset[_key][batch_index]
                    for _key in feed_key[TEST]
                }
                feed_dict[nn.var(TEST_MODE)] = 1
                res = sess.run(fetches=fetches[TEST], feed_dict=feed_dict)
                prob_predict += res[PROB_PREDICT].tolist()
                labels_predict += res[LABEL_PREDICT].tolist()
                hidden_feats += res[HIDDEN_FEAT].tolist()
                labels_gold += dataset[LABEL_GOLD][batch_index].tolist()

            prob_predict = prob_predict[:n_sample]
            labels_predict = labels_predict[:n_sample]
            labels_gold = labels_gold[:n_sample]
            hidden_feats = hidden_feats[:n_sample]

            if mode == TEST:
                res = basic_evaluate(gold=labels_gold,
                                     pred=labels_predict,
                                     pos_label=pos_label)
                best_res[TEST] = res

            # Export the hidden-layer features.
            with open(data_config.output_path(output_key, mode, HIDDEN_FEAT),
                      'w') as file_obj:
                for _feat in hidden_feats:
                    file_obj.write('\t'.join(map(str, _feat)) + '\n')
            # Export the predicted labels.
            with open(data_config.output_path(output_key, mode, LABEL_PREDICT),
                      'w') as file_obj:
                for _label in labels_predict:
                    file_obj.write('{}\n'.format(_label))
            # Export the predicted probabilities.
            with open(data_config.output_path(output_key, mode, PROB_PREDICT),
                      'w') as file_obj:
                for _prob in prob_predict:
                    file_obj.write('\t'.join(map(str, _prob)) + '\n')

        for mode in [TRAIN, VALID, TEST]:
            if mode == VALID and train_config.valid_rate == 0.:
                continue
            res = best_res[mode]
            print(mode)
            print_evaluation(res)

            with open(data_config.output_path(output_key, mode, EVALUATION),
                      'w') as file_obj:
                json.dump(res, file_obj)
            print()

    print('OUTPUT_KEY: {}'.format(output_key))