import importlib
import json
import re

import numpy as np
import tensorflow as tf
import yaml

# NOTE: project-level names used below (data_config, NNConfig, BaseNNModel,
# load_dataset, SimpleIndexIterator, tokenized_to_tid_list, zero_pad_seq_list,
# feed_key, label_str, and the upper-case constants such as ALL, CONFIG, FINAL,
# VOCAB_ID_MAPPING, LABEL_PREDICT, PROB_PREDICT, HIDDEN_FEAT, TEST_MODE, TID_0,
# TID_1, TID_2, SEQ_LEN, TRAIN, TEST, WORD, TEXT, MAX_SEQ_LEN) are assumed to be
# imported from this repository's own modules elsewhere in the file.


def predict_final(output_key, output_labels):
    config_path = data_config.output_path(output_key, ALL, CONFIG)
    with open(config_path) as config_file:
        config_data = yaml.load(config_file, Loader=yaml.SafeLoader)
    nn_config = NNConfig(config_data)
    with open(data_config.output_path(output_key, ALL, VOCAB_ID_MAPPING), 'r') as vocab_file:
        vocab_id_mapping = json.load(vocab_file)

    dataset = load_dataset(
        mode=FINAL, vocab_id_mapping=vocab_id_mapping,
        max_seq_len=nn_config.seq_len, sampling=False, with_label=False
    )
    index_iterator = SimpleIndexIterator.from_dataset(dataset)
    n_sample = index_iterator.n_sample()

    with tf.Session() as sess:
        # restore the latest checkpoint of the trained model
        prefix_checkpoint = tf.train.latest_checkpoint(data_config.model_path(key=output_key))
        saver = tf.train.import_meta_graph('{}.meta'.format(prefix_checkpoint))
        saver.restore(sess, prefix_checkpoint)

        nn = BaseNNModel(config=None)
        nn.set_graph(tf.get_default_graph())

        fetches = {_key: nn.var(_key) for _key in [LABEL_PREDICT]}
        labels_predict = list()
        for batch_index in index_iterator.iterate(nn_config.batch_size, shuffle=False):
            feed_dict = {nn.var(_key): dataset[_key][batch_index] for _key in feed_key[TEST]}
            feed_dict[nn.var(TEST_MODE)] = 1
            # bug fix: `fetches` here is keyed by variable name, not by mode,
            # so it is passed as a whole rather than indexed with TEST
            res = sess.run(fetches=fetches, feed_dict=feed_dict)
            labels_predict += res[LABEL_PREDICT].tolist()
        # the iterator pads the last batch, so cut back to the true sample count
        labels_predict = labels_predict[:n_sample]

    with open(output_labels, 'w') as file_obj:
        for i, label in enumerate(labels_predict):
            # bug fix: a newline was missing, which collapsed all rows into one line
            file_obj.write('{},{},{}\n'.format(i, label, label_str[label]))
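
# Why `labels_predict[:n_sample]` above: a runnable sketch of the assumed batching
# behavior (this is an assumption about SimpleIndexIterator's internals, not its
# actual source) -- the last batch is filled by wrapping around to the start, so
# the tail of the prediction list repeats early samples and must be truncated:
def _wraparound_batches(n_sample, batch_size):
    """Illustrative helper, not part of the project API."""
    indices = np.arange(n_sample)
    n_batches = -(-n_sample // batch_size)               # ceiling division
    padded = np.resize(indices, n_batches * batch_size)  # np.resize repeats to fill
    return [padded[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]
# e.g. _wraparound_batches(5, 2) -> [[0, 1], [2, 3], [4, 0]]: index 0 recurs, so
# the sixth prediction is a duplicate and is dropped by the [:n_sample] slice.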

def live_test(output_key):
    config_path = data_config.output_path(output_key, ALL, CONFIG)
    with open(config_path) as config_file:
        config_data = yaml.load(config_file, Loader=yaml.SafeLoader)
    nn_config = NNConfig(config_data)
    with open(data_config.output_path(output_key, ALL, VOCAB_ID_MAPPING), 'r') as vocab_file:
        vocab_id_mapping = json.load(vocab_file)

    with tf.Session() as sess:
        prefix_checkpoint = tf.train.latest_checkpoint(data_config.model_path(key=output_key))
        saver = tf.train.import_meta_graph('{}.meta'.format(prefix_checkpoint))
        saver.restore(sess, prefix_checkpoint)

        nn = BaseNNModel(config=None)
        nn.set_graph(tf.get_default_graph())

        fetches = {_key: nn.var(_key) for _key in [LABEL_PREDICT, PROB_PREDICT]}

        # read conversations from stdin until the user types "quit";
        # each line must contain exactly three turns separated by '|'
        while True:
            line = input('input: ')
            if line == 'quit':
                break
            turns = line.strip().split('|')
            if len(turns) != 3:
                print('invalid turns')
                continue

            # normalize whitespace and tokenize each turn
            tokens_list = list()
            for turn in turns:
                tokens = re.sub(r'\s+', ' ', turn.strip()).split(' ')
                tokens_list.append(tokens)

            # the graph expects a full batch, so pad with empty placeholder sequences
            placeholder = [[]] * (nn_config.batch_size - 1)
            tid_list_0 = tokenized_to_tid_list([tokens_list[0], ] + placeholder, vocab_id_mapping)
            tid_list_1 = tokenized_to_tid_list([tokens_list[1], ] + placeholder, vocab_id_mapping)
            tid_list_2 = tokenized_to_tid_list([tokens_list[2], ] + placeholder, vocab_id_mapping)
            tid_0 = np.asarray(zero_pad_seq_list(tid_list_0, nn_config.seq_len))
            tid_1 = np.asarray(zero_pad_seq_list(tid_list_1, nn_config.seq_len))
            tid_2 = np.asarray(zero_pad_seq_list(tid_list_2, nn_config.seq_len))

            feed_dict = {
                nn.var(TID_0): tid_0,
                nn.var(TID_1): tid_1,
                nn.var(TID_2): tid_2,
                nn.var(TEST_MODE): 1
            }
            res = sess.run(fetches=fetches, feed_dict=feed_dict)
            # only the first row of the batch is real input; the rest is padding
            label = res[LABEL_PREDICT][0]
            prob = res[PROB_PREDICT][0]
            print('label: {}'.format(label))
            print('prob: {}'.format(prob))
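
# A minimal, self-contained sketch of the fixed-batch padding trick used in
# live_test (the helper and its constants are illustrative, not project API;
# it assumes zero_pad_seq_list right-pads each sequence with zeros to seq_len):
def _pad_demo(batch_size=4, seq_len=6):
    """Illustrative helper, not part of the project API."""
    tid_list = [[3, 1, 4]] + [[]] * (batch_size - 1)  # one real sample + placeholders
    padded = np.asarray([seq + [0] * (seq_len - len(seq)) for seq in tid_list])
    assert padded.shape == (batch_size, seq_len)      # the graph gets a full batch
    return padded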

def build_feat(dataset_key_src, output_key_src, dataset_key_dest='semeval2018_task3', text_version=TEXT):
    """
    python algo/main.py feat semeval2014_task9 A_gru_ek_1541081331
    python algo/main.py feat semeval2018_task1 love_gru_1539178720

    :param dataset_key_src: dataset_key of the dataset the model was trained on
    :param output_key_src: output_key of the trained model
    :param dataset_key_dest: dataset_key of the dataset to generate feature vectors for
    :param text_version: which preprocessed text variant to featurize
    :return:
    """
    output_key = '{}.{}'.format(dataset_key_src, output_key_src)
    print('OUTPUT_KEY: {}'.format(output_key))

    # locate the directory holding the trained model's checkpoint files
    data_src_config = getattr(
        importlib.import_module('dataset.{}.config'.format(dataset_key_src)), 'config')
    model_output_prefix = data_src_config.model_path(key=output_key_src)

    # load the vocabulary mapping the model was trained with
    with open(data_src_config.output_path(output_key_src, ALL, VOCAB_ID_MAPPING), 'r') as vocab_file:
        vocab_id_mapping = json.load(vocab_file)

    # load the destination dataset (train and test splits)
    data_config = getattr(
        importlib.import_module('dataset.{}.config'.format(dataset_key_dest)), 'config')
    data_config.prepare_output_folder(output_key=output_key)

    datasets = load_dataset(
        data_config=data_config, analyzer=WORD, vocab_id_mapping=vocab_id_mapping,
        seq_len=MAX_SEQ_LEN, with_label=False, text_version=text_version)

    batch_size = 200
    with tf.Session() as sess:
        prefix_checkpoint = tf.train.latest_checkpoint(model_output_prefix)
        saver = tf.train.import_meta_graph('{}.meta'.format(prefix_checkpoint))
        saver.restore(sess, prefix_checkpoint)

        nn = BaseNNModel(config=None)
        nn.set_graph(tf.get_default_graph())

        fetches = {
            mode: {_key: nn.var(_key) for _key in [HIDDEN_FEAT, ]}
            for mode in [TEST, ]
        }

        for mode in [TRAIN, TEST]:
            dataset = datasets[mode]
            index_iterator = SimpleIndexIterator(n_sample=dataset[SEQ_LEN].shape[0])
            n_sample = index_iterator.n_sample()

            hidden_feats = list()
            # the model always runs in test mode here, whichever split is being fed;
            # shuffle=False (as in predict_final) keeps features aligned with sample order
            for batch_index in index_iterator.iterate(batch_size, shuffle=False):
                feed_dict = {nn.var(_key): dataset[_key][batch_index] for _key in feed_key[TEST]}
                feed_dict[nn.var(TEST_MODE)] = 1  # as in predict_final: run in inference mode
                res = sess.run(fetches=fetches[TEST], feed_dict=feed_dict)
                hidden_feats += res[HIDDEN_FEAT].tolist()
            hidden_feats = hidden_feats[:n_sample]

            # export the hidden-layer features, one tab-separated row per sample
            with open(data_config.output_path(output_key, mode, HIDDEN_FEAT), 'w') as file_obj:
                for _feat in hidden_feats:
                    file_obj.write('\t'.join(map(str, _feat)) + '\n')

    print('OUTPUT_KEY: {}'.format(output_key))
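
# Reading the exported features back, a minimal sketch (the tab-separated format
# mirrors the writer above; the helper itself is illustrative, not project API):
def _load_hidden_feat(path):
    """Illustrative helper, not part of the project API."""
    with open(path) as file_obj:
        return np.asarray([[float(v) for v in line.rstrip('\n').split('\t')]
                           for line in file_obj])
# e.g. feats = _load_hidden_feat(data_config.output_path(output_key, TEST, HIDDEN_FEAT))
# yields an (n_sample, feat_dim) array ready to feed a downstream classifier.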