def main(_):
  """Run one hard-coded (query, passage) pair through the restored model
  in eager mode and print probabilities, the argmax class, and per-word
  importance scores from each pooling layer.

  Side effects: initializes melt/ids2text, restores FLAGS.model weights,
  and prints everything to stdout. Returns nothing.
  """
  melt.apps.init()
  #ev.init()
  model = getattr(base, FLAGS.model)()
  model.debug = True  # expose internals such as model.pooling.word_scores
  melt.eager.restore(model)
  ids2text.init()
  vocab = ids2text.vocab

  # Alternative probe pairs kept for quick manual experiments:
  # query = '阿里和腾讯谁更流氓'
  # passage = '腾讯比阿里流氓'
  # query = 'c罗和梅西谁踢球更好'
  # passage = '梅西比c罗踢的好'
  query = '青光眼遗传吗'
  passage = '青光眼有遗传因素的,所以如果是您的父亲是青光眼的话,那我们在这里就强烈建议您,自己早期到医院里面去做一个筛查,测一下,看看眼,尤其是检查一下视野,然后视网膜的那个情况,都做一个早期的检查。'

  qids = text2ids(query)
  qwords = [vocab.key(qid) for qid in qids]  # kept for interactive inspection
  print(qids)
  print(ids2text.ids2text(qids))

  pids = text2ids(passage)
  pwords = [vocab.key(pid) for pid in pids]
  print(pids)
  print(ids2text.ids2text(pids))

  # Batch of one example. 'type' 0 — presumably a task/pair-type id;
  # TODO confirm against the model's input spec.
  x = {
    'query': [qids],
    'passage': [pids],
    'type': [0],
  }
  logits = model(x)[0]
  probs = gezi.softmax(logits)
  print(probs)
  # was: [x for x in probs] — a needless copy that also shadowed the
  # feature dict `x`; list() is equivalent and clearer.
  print(list(zip(CLASSES, list(probs))))
  predict = np.argmax(logits, -1)
  print('predict', predict, CLASSES[predict])

  # Print per-word importance scores from every pooling layer.
  word_scores_list = model.pooling.word_scores
  for word_scores in word_scores_list:
    print(list(zip(pwords, word_scores[0].numpy())))
def main(_):
  """Run one hard-coded review sentence through the restored
  attribute-sentiment model in eager mode and print, per attribute:
  probabilities, raw argmax predictions, adjusted predictions, and
  per-word importance scores from each pooling layer.

  Side effects: initializes melt/ids2text, restores FLAGS.model weights,
  and prints everything to stdout. Returns nothing.
  """
  melt.apps.init()
  #ev.init()
  model = getattr(base, FLAGS.model)()
  model.debug = True  # expose internals such as model.pooling.word_scores
  melt.eager.restore(model)
  ids2text.init()
  vocab = ids2text.vocab

  # Alternative probe sentences. The original code assigned all five in a
  # row, so only the last was live — the dead stores are kept as comments:
  # content = '这是一个很好的餐馆,菜很不好吃,我还想再去'
  # content = '这是一个很差的餐馆,菜很不好吃,我不想再去'
  # content = '这是一个很好的餐馆,菜很好吃,我还想再去'
  # content = '这是一个很好的餐馆,只是菜很难吃,我还想再去'
  content = '这是一个很好的餐馆,只是菜很不好吃,我还想再去'

  cids = text2ids(content)
  words = [vocab.key(cid) for cid in cids]
  print(cids)
  print(ids2text.ids2text(cids))

  x = {'content': [cids]}
  logits = model(x)[0]
  probs = gezi.softmax(logits, 1)
  print(probs)
  print(list(zip(ATTRIBUTES, [list(p) for p in probs])))

  # Raw argmax per attribute, shifted by -2 — presumably mapping class
  # index to the sentiment label range; confirm against training labels.
  predicts = np.argmax(logits, -1) - 2
  print('predicts ', predicts)
  print(list(zip(ATTRIBUTES, predicts)))

  # Threshold/calibration-adjusted predictions from the eval module.
  adjusted_predicts = ev.to_predict(logits)
  print('apredicts', adjusted_predicts)
  print(list(zip(ATTRIBUTES, adjusted_predicts)))

  # Print per-word importance scores from every pooling layer.
  word_scores_list = model.pooling.word_scores
  for word_scores in word_scores_list:
    print(list(zip(words, word_scores[0].numpy())))
def main(_):
  """Iterate a toxic-comment TFRecord dataset and print one sample batch
  (label, id sequence, decoded text, raw string) for eyeballing.

  Reads FLAGS.input (switched to the test split when FLAGS.type == 'test'),
  optionally drops the FLAGS.fold record file, then scans batches until the
  first printable one. Returns nothing.
  """
  logging.set_logging_path('./mount/tmp/')
  vocab_path = os.path.join(os.path.dirname(os.path.dirname(FLAGS.input)), 'vocab.txt')
  ids2text.init(vocab_path)
  # NOTE(review): this overrides the vocab derived from FLAGS.input above
  # with a fixed glove vocab — looks intentional, but confirm.
  FLAGS.vocab = './mount/temp/kaggle/toxic/tfrecords/glove/vocab.txt'
  FLAGS.length_index = 2
  #FLAGS.length_index = 1
  FLAGS.buckets = '100,400'
  FLAGS.batch_sizes = '64,64,32'

  input_ = FLAGS.input
  if FLAGS.type == 'test':
    input_ = input_.replace('train', 'test')
  inputs = gezi.list_files(input_)
  inputs.sort()
  if FLAGS.fold is not None:
    # Hold out the fold's record file.
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]

  if FLAGS.type != 'dump':
    print('type', FLAGS.type, 'inputs', inputs, file=sys.stderr)
    dataset = Dataset('valid')
    dataset = dataset.make_batch(FLAGS.batch_size_, inputs)
    print('dataset', dataset)
    # timer presumably reports elapsed time elsewhere (e.g. on destruction);
    # kept as in the original — verify gezi.Timer semantics.
    timer = gezi.Timer('read record')
    for i, (x, y) in enumerate(dataset):
      # Print the first batch where i % 10 == 1 (i.e. the second batch),
      # then stop. The original's dead `else: pass` was removed.
      if i % 10 == 1:
        print(y[0])
        print(x['comment'][0])
        print(ids2text.ids2text(x['comment'][0], sep='|'))
        print(x['comment_str'][0])
        break
def _show_unk_info(infos):
  """Print UNK-token statistics for one info dict (doc id -> record with a
  'content' id sequence) plus up to six example docs containing UNKs."""
  lens = [len(infos[key]['content']) for key in infos]
  unks = [list(infos[key]['content']).count(FLAGS.unk_id) for key in infos]
  print('num unks per doc:', sum(unks) / len(infos))
  print('num doc with unk ratio:', len([x for x in unks if x != 0]) / len(unks))
  print('un unk tokens ratio:', sum(unks) / sum(lens))
  print('len max:', np.max(lens))
  print('len min:', np.min(lens))
  print('len mean:', np.mean(lens))
  print('num docs:', len(infos))
  num_show = 0
  for key in infos:
    if list(infos[key]['content']).count(FLAGS.unk_id) > 0:
      print(infos[key])
      print(ids2text.ids2text(infos[key]['content']))
      num_show += 1
      if num_show > 5:
        break


def main(_):
  """Utility driver over the dataset's TFRecords, dispatched on FLAGS.type:

  - 'debug':     scan batches and print the content string of doc id '573'.
  - 'dump':      run `deal` over train and test records and pickle the
                 resulting info dicts to {base}/info.pkl / info.test.pkl.
  - 'show_info': load both pickles and print UNK/length statistics.
  - 'test':      only switches the input glob from valid to test (falls
                 through to the ValueError below if nothing else matches).

  Raises ValueError for any other FLAGS.type. Returns nothing.
  """
  base = FLAGS.base
  logging.set_logging_path('./mount/tmp/')
  vocab_path = f'{base}/vocab.txt'
  ids2text.init(vocab_path)
  FLAGS.vocab = f'{base}/vocab.txt'
  # FLAGS.length_index = 2
  # FLAGS.buckets = '100,400'
  # FLAGS.batch_sizes = '64,64,32'

  input_ = FLAGS.input
  if FLAGS.type == 'test':
    input_ = input_.replace('valid', 'test')
  inputs = gezi.list_files(input_)
  inputs.sort()
  if FLAGS.fold is not None:
    # Hold out the fold's record file.
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]

  if FLAGS.type == 'debug':
    print('type', FLAGS.type, 'inputs', inputs, file=sys.stderr)
    dataset = Dataset('valid')
    dataset = dataset.make_batch(FLAGS.batch_size_, inputs)
    print('dataset', dataset)
    # timer presumably reports elapsed time on its own; kept as in the
    # original — verify gezi.Timer semantics.
    timer = gezi.Timer('read record')
    for i, (x, y) in enumerate(dataset):
      x['id'] = gezi.decode(x['id'].numpy())
      x['content_str'] = gezi.decode(x['content_str'].numpy())
      # id_ avoids shadowing the builtin id().
      for j, id_ in enumerate(x['id']):
        if id_ == '573':
          print(id_, x['content_str'][j])
  elif FLAGS.type == 'dump':
    valid_infos = {}
    test_infos = {}
    # TODO notice train and valid also share ids.. so valid only save 0 is
    # ok... 120000 doc but first 15000 train duplicate id with valid so
    # only save valid result for those ids currently
    inputs = gezi.list_files(f'{base}/train/*record')
    dataset = Dataset('valid')
    dataset = dataset.make_batch(1, inputs)
    deal(dataset, valid_infos)
    print('after valid', len(valid_infos))
    for key in valid_infos:
      # Show one sample record for sanity.
      print(valid_infos[key])
      print(ids2text.ids2text(valid_infos[key]['content']))
      break
    ofile = f'{base}/info.pkl'
    with open(ofile, 'wb') as out:
      pickle.dump(valid_infos, out)
    del valid_infos  # free memory before processing the test split
    inputs = gezi.list_files(f'{base}/test/*record')
    dataset = Dataset('test')
    dataset = dataset.make_batch(1, inputs)
    deal(dataset, test_infos)
    print('after test', len(test_infos))
    ofile = ofile.replace('.pkl', '.test.pkl')
    with open(ofile, 'wb') as out:
      pickle.dump(test_infos, out)
    for key in test_infos:
      print(test_infos[key])
      print(ids2text.ids2text(test_infos[key]['content']))
      break
  elif FLAGS.type == 'show_info':
    # was: pickle.load(open(...)) — leaked the file handle; the duplicated
    # valid/test statistics code is now shared via _show_unk_info.
    with open(f'{base}/info.pkl', 'rb') as f:
      valid_infos = pickle.load(f)
    _show_unk_info(valid_infos)
    del valid_infos
    print('--------------for test info:')
    with open(f'{base}/info.test.pkl', 'rb') as f:
      test_infos = pickle.load(f)
    _show_unk_info(test_infos)
  else:
    raise ValueError(FLAGS.type)