Example #1
def predict_final(output_key, output_labels):
    config_path = data_config.output_path(output_key, ALL, CONFIG)
    with open(config_path) as config_file:
        config_data = yaml.safe_load(config_file)
    nn_config = NNConfig(config_data)
    with open(data_config.output_path(output_key, ALL, VOCAB_ID_MAPPING), 'r') as mapping_file:
        vocab_id_mapping = json.load(mapping_file)

    dataset = load_dataset(
        mode=FINAL, vocab_id_mapping=vocab_id_mapping,
        max_seq_len=nn_config.seq_len, sampling=False, with_label=False
    )
    index_iterator = SimpleIndexIterator.from_dataset(dataset)
    n_sample = index_iterator.n_sample()

    with tf.Session() as sess:
        prefix_checkpoint = tf.train.latest_checkpoint(data_config.model_path(key=output_key))
        saver = tf.train.import_meta_graph('{}.meta'.format(prefix_checkpoint))
        saver.restore(sess, prefix_checkpoint)

        nn = BaseNNModel(config=None)
        nn.set_graph(tf.get_default_graph())

        fetches = {_key: nn.var(_key) for _key in [LABEL_PREDICT]}
        labels_predict = list()

        for batch_index in index_iterator.iterate(nn_config.batch_size, shuffle=False):
            feed_dict = {nn.var(_key): dataset[_key][batch_index] for _key in feed_key[TEST]}
            feed_dict[nn.var(TEST_MODE)] = 1
            res = sess.run(fetches=fetches, feed_dict=feed_dict)
            labels_predict += res[LABEL_PREDICT].tolist()

        # drop the padding appended by the iterator to fill the last batch
        labels_predict = labels_predict[:n_sample]

    with open(output_labels, 'w') as file_obj:
        for i, label in enumerate(labels_predict):
            file_obj.write('{},{},{}\n'.format(i, label, label_str[label]))
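Note: the truncation to n_sample above implies that SimpleIndexIterator pads its final batch up to batch_size. A minimal sketch of that padding behavior, assuming an iterator that wraps around to fill the last batch (the real SimpleIndexIterator is repo-specific):

import numpy as np

def iterate_padded(n_sample, batch_size):
    # Yield index batches of exactly batch_size; the final batch wraps
    # around to the start of the index list, which is why predict_final
    # truncates the collected predictions back to n_sample.
    indices = np.arange(n_sample)
    padded = np.concatenate([indices, indices[:(-n_sample) % batch_size]])
    for start in range(0, len(padded), batch_size):
        yield padded[start:start + batch_size]

# 10 samples, batch size 4 -> [0..3], [4..7], [8, 9, 0, 1]
for batch in iterate_padded(10, 4):
    print(batch)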
Example #2
def live_test(output_key):
    config_path = data_config.output_path(output_key, ALL, CONFIG)
    with open(config_path) as config_file:
        config_data = yaml.safe_load(config_file)
    nn_config = NNConfig(config_data)
    with open(data_config.output_path(output_key, ALL, VOCAB_ID_MAPPING), 'r') as mapping_file:
        vocab_id_mapping = json.load(mapping_file)

    with tf.Session() as sess:
        prefix_checkpoint = tf.train.latest_checkpoint(data_config.model_path(key=output_key))
        saver = tf.train.import_meta_graph('{}.meta'.format(prefix_checkpoint))
        saver.restore(sess, prefix_checkpoint)

        nn = BaseNNModel(config=None)
        nn.set_graph(tf.get_default_graph())

        fetches = {_key: nn.var(_key) for _key in [LABEL_PREDICT, PROB_PREDICT]}
        while True:
            res = input('input: ')
            if res == 'quit':
                break

            turns = res.strip().split('|')
            if len(turns) != 3:
                print('invalid turns')
                continue

            tokens_list = list()
            for turn in turns:
                tokens = re.sub(r'\s+', ' ', turn.strip()).split(' ')
                tokens_list.append(tokens)

            # pad the batch with empty token lists so it matches the graph's fixed batch size
            placeholder = [[]] * (nn_config.batch_size - 1)
            tid_list_0 = tokenized_to_tid_list([tokens_list[0], ] + placeholder, vocab_id_mapping)
            tid_list_1 = tokenized_to_tid_list([tokens_list[1], ] + placeholder, vocab_id_mapping)
            tid_list_2 = tokenized_to_tid_list([tokens_list[2], ] + placeholder, vocab_id_mapping)

            tid_0 = np.asarray(zero_pad_seq_list(tid_list_0, nn_config.seq_len))
            tid_1 = np.asarray(zero_pad_seq_list(tid_list_1, nn_config.seq_len))
            tid_2 = np.asarray(zero_pad_seq_list(tid_list_2, nn_config.seq_len))

            feed_dict = {
                nn.var(TID_0): tid_0,
                nn.var(TID_1): tid_1,
                nn.var(TID_2): tid_2,
                nn.var(TEST_MODE): 1
            }
            res = sess.run(fetches=fetches, feed_dict=feed_dict)
            label = res[LABEL_PREDICT][0]
            prob = res[PROB_PREDICT][0]
            print('label: {}'.format(label))
            print('prob: {}'.format(prob))
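The placeholder trick above (one real example plus batch_size - 1 empty token lists, then reading index 0 of the result) is a common way to run a fixed-batch-size graph on a single input. A minimal self-contained illustration of the same idea in NumPy, independent of the repo's tokenized_to_tid_list and zero_pad_seq_list helpers:

import numpy as np

def pad_single_example(token_ids, batch_size, seq_len):
    # Build a (batch_size, seq_len) int matrix whose first row holds the
    # real example (zero-padded to seq_len) and whose remaining rows are
    # all zeros; only row 0 of the model output is meaningful.
    batch = np.zeros((batch_size, seq_len), dtype=np.int64)
    ids = token_ids[:seq_len]
    batch[0, :len(ids)] = ids
    return batch

print(pad_single_example([5, 8, 2], batch_size=4, seq_len=6))
# [[5 8 2 0 0 0]
#  [0 0 0 0 0 0]
#  [0 0 0 0 0 0]
#  [0 0 0 0 0 0]]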
Example #3
def build_feat(dataset_key_src,
               output_key_src,
               dataset_key_dest='semeval2018_task3',
               text_version=TEXT):
    """
    python algo/main.py feat semeval2014_task9 A_gru_ek_1541081331
    python algo/main.py feat semeval2018_task1 love_gru_1539178720

    :param dataset_key_src: 模型对应的dataset_key
    :param output_key_src: 模型对应的output_key
    :param dataset_key_dest: 需要生成特征向量的数据集对应的dataset_key
    :param text_version:
    :return:
    """
    output_key = '{}.{}'.format(dataset_key_src, output_key_src)
    print('OUTPUT_KEY: {}'.format(output_key))

    # Get the path where the model files are stored
    data_src_config = getattr(
        importlib.import_module('dataset.{}.config'.format(dataset_key_src)),
        'config')
    model_output_prefix = data_src_config.model_path(key=output_key_src)
    # Load the vocabulary-to-id mapping used by the model
    with open(data_src_config.output_path(output_key_src, ALL, VOCAB_ID_MAPPING), 'r') as mapping_file:
        vocab_id_mapping = json.load(mapping_file)

    # Load the training data
    data_config = getattr(
        importlib.import_module('dataset.{}.config'.format(dataset_key_dest)),
        'config')
    data_config.prepare_output_folder(output_key=output_key)
    datasets = load_dataset(data_config=data_config,
                            analyzer=WORD,
                            vocab_id_mapping=vocab_id_mapping,
                            seq_len=MAX_SEQ_LEN,
                            with_label=False,
                            text_version=text_version)
    batch_size = 200

    with tf.Session() as sess:
        prefix_checkpoint = tf.train.latest_checkpoint(model_output_prefix)
        saver = tf.train.import_meta_graph('{}.meta'.format(prefix_checkpoint))
        saver.restore(sess, prefix_checkpoint)

        nn = BaseNNModel(config=None)
        nn.set_graph(tf.get_default_graph())
        fetches = {
            mode: {_key: nn.var(_key) for _key in [HIDDEN_FEAT]}
            for mode in [TEST]
        }

        for mode in [TRAIN, TEST]:
            dataset = datasets[mode]
            index_iterator = SimpleIndexIterator(
                n_sample=dataset[SEQ_LEN].shape[0])
            n_sample = index_iterator.n_sample()

            hidden_feats = list()
            for batch_index in index_iterator.iterate(batch_size):
                feed_dict = {
                    nn.var(_key): dataset[_key][batch_index]
                    for _key in feed_key[TEST]
                }
                res = sess.run(fetches=fetches[TEST], feed_dict=feed_dict)
                hidden_feats += res[HIDDEN_FEAT].tolist()
            # drop the padding appended by the iterator to fill the last batch
            hidden_feats = hidden_feats[:n_sample]

            # Export the hidden-layer features
            with open(data_config.output_path(output_key, mode, HIDDEN_FEAT),
                      'w') as file_obj:
                for _feat in hidden_feats:
                    file_obj.write('\t'.join(map(str, _feat)) + '\n')

    print('OUTPUT_KEY: {}'.format(output_key))
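The exported files are plain tab-separated floats, one hidden-state vector per line, so they can be read back without any repo code. A minimal sketch for loading them, assuming the same path that build_feat writes to:

import numpy as np

def load_hidden_feats(path):
    # Each line is one vector with values joined by tabs, matching the
    # '\t'.join(map(str, _feat)) format written by build_feat.
    with open(path) as file_obj:
        return np.asarray([
            [float(v) for v in line.split('\t')]
            for line in file_obj if line.strip()
        ])

# feats = load_hidden_feats(data_config.output_path(output_key, TEST, HIDDEN_FEAT))
# feats.shape -> (n_sample, hidden_dim)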