Code example #1
def main(_):
    melt.apps.init()

    #ev.init()

    model = getattr(base, FLAGS.model)()
    model.debug = True

    melt.eager.restore(model)

    ids2text.init()
    vocab = ids2text.vocab

    # query = '阿里和腾讯谁更流氓'
    # passage = '腾讯比阿里流氓'

    # query = 'c罗和梅西谁踢球更好'
    # passage = '梅西比c罗踢的好'
    # Active example: query "Is glaucoma hereditary?" with a passage stating glaucoma has a genetic component.
    query = '青光眼遗传吗'
    passage = '青光眼有遗传因素的,所以如果是您的父亲是青光眼的话,那我们在这里就强烈建议您,自己早期到医院里面去做一个筛查,测一下,看看眼,尤其是检查一下视野,然后视网膜的那个情况,都做一个早期的检查。'

    qids = text2ids(query)
    qwords = [vocab.key(qid) for qid in qids]
    print(qids)
    print(ids2text.ids2text(qids))
    pids = text2ids(passage)
    pwords = [vocab.key(pid) for pid in pids]
    print(pids)
    print(ids2text.ids2text(pids))

    x = {
        'query': [qids],
        'passage': [pids],
        'type': [0],
    }

    logits = model(x)[0]
    probs = gezi.softmax(logits)
    print(probs)
    print(list(zip(CLASSES, probs)))

    predict = np.argmax(logits, -1)
    print('predict', predict, CLASSES[predict])

    # print words importance scores
    word_scores_list = model.pooling.word_scores

    for word_scores in word_scores_list:
        print(list(zip(pwords, word_scores[0].numpy())))
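The loop above pairs each passage token with its attention weight from the pooling layer. As a side note, a small NumPy sketch for pulling out the highest-weighted tokens; top_tokens is a hypothetical helper, not part of the project, and it assumes each word_scores entry is a [batch, seq_len] array so that word_scores[0] lines up with pwords:

import numpy as np

def top_tokens(words, scores, k=5):
    # Hypothetical helper: return the k (token, score) pairs with the largest attention scores.
    scores = np.asarray(scores, dtype=np.float32)
    order = np.argsort(-scores)[:k]
    return [(words[i], float(scores[i])) for i in order]

# Dummy data shaped like pwords / word_scores[0].numpy() above.
words = ['青光眼', '有', '遗传', '因素', '的']
scores = [0.41, 0.03, 0.35, 0.18, 0.03]
print(top_tokens(words, scores, k=3))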
Code example #2
def main(_):
    melt.apps.init()

    #ev.init()

    model = getattr(base, FLAGS.model)()
    model.debug = True

    melt.eager.restore(model)

    ids2text.init()
    vocab = ids2text.vocab

    # Several candidate sentences; only the last assignment is actually used.
    content = '这是一个很好的餐馆,菜很不好吃,我还想再去'
    content = '这是一个很差的餐馆,菜很不好吃,我不想再去'
    content = '这是一个很好的餐馆,菜很好吃,我还想再去'
    content = '这是一个很好的餐馆,只是菜很难吃,我还想再去'
    # Active example: "This is a very good restaurant, it's just that the food is quite bad; I'd still go again."
    content = '这是一个很好的餐馆,只是菜很不好吃,我还想再去'

    cids = text2ids(content)
    words = [vocab.key(cid) for cid in cids]
    print(cids)
    print(ids2text.ids2text(cids))
    x = {'content': [cids]}
    logits = model(x)[0]
    probs = gezi.softmax(logits, 1)
    print(probs)
    print(list(zip(ATTRIBUTES, [list(x) for x in probs])))

    predicts = np.argmax(logits, -1) - 2
    print('predicts ', predicts)
    print(list(zip(ATTRIBUTES, predicts)))
    adjusted_predicts = ev.to_predict(logits)
    print('apredicts', adjusted_predicts)
    print(list(zip(ATTRIBUTES, adjusted_predicts)))

    # print words importance scores
    word_scores_list = model.pooling.word_scores

    for word_scores in word_scores_list:
        print(list(zip(words, word_scores[0].numpy())))
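For reference, the two NumPy-level steps in this example are easy to reproduce in isolation. A minimal sketch, assuming gezi.softmax(logits, 1) is a standard softmax along axis 1 and that the - 2 shift maps class indices 0..3 onto sentiment labels -2..1:

import numpy as np

def softmax(logits, axis=-1):
    # Numerically stable softmax; assumed equivalent to gezi.softmax.
    z = logits - np.max(logits, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

# Dummy logits: 2 attributes x 4 sentiment classes.
logits = np.array([[0.1, 2.0, 0.3, -1.0],
                   [1.5, 0.2, 0.1, 0.0]])
probs = softmax(logits, axis=1)
predicts = np.argmax(logits, -1) - 2   # class index 0..3 -> label -2..1
print(probs)
print(predicts)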
Code example #3
def main(_):
  logging.set_logging_path('./mount/tmp/')
  vocab_path = os.path.join(os.path.dirname(os.path.dirname(FLAGS.input)), 'vocab.txt')
  ids2text.init(vocab_path)
  FLAGS.vocab = './mount/temp/kaggle/toxic/tfrecords/glove/vocab.txt'

  FLAGS.length_index = 2
  #FLAGS.length_index = 1
  FLAGS.buckets = '100,400'
  FLAGS.batch_sizes = '64,64,32'

  input_ = FLAGS.input 
  if FLAGS.type == 'test':
    input_ = input_.replace('train', 'test')

  inputs = gezi.list_files(input_)
  inputs.sort()
  if FLAGS.fold is not None:
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]

  if FLAGS.type != 'dump':
    print('type', FLAGS.type, 'inputs', inputs, file=sys.stderr)

    dataset = Dataset('valid')
    dataset = dataset.make_batch(FLAGS.batch_size_, inputs)

    print('dataset', dataset)

    timer = gezi.Timer('read record')
    for i, (x, y) in enumerate(dataset):
      if i % 10 == 1:
        print(y[0])
        print(x['comment'][0])
        print(ids2text.ids2text(x['comment'][0], sep='|'))
        print(x['comment_str'][0])
        break
  else:
    pass
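FLAGS.buckets = '100,400' together with FLAGS.batch_sizes = '64,64,32' describes length bucketing: sequences up to 100 tokens and those between 100 and 400 are batched 64 at a time, longer ones 32 at a time (FLAGS.length_index presumably selects which record field holds the length). The project's Dataset.make_batch handles this internally; a rough, assumed-equivalent illustration with stock tf.data, not the project's actual implementation:

import tensorflow as tf

# Toy dataset of variable-length id sequences.
ds = tf.data.Dataset.from_generator(
    lambda: ([1] * n for n in (20, 150, 500, 80)),
    output_types=tf.int32, output_shapes=[None])

# Bucket boundaries and per-bucket batch sizes mirror '100,400' / '64,64,32'.
ds = ds.apply(tf.data.experimental.bucket_by_sequence_length(
    element_length_func=lambda x: tf.shape(x)[0],
    bucket_boundaries=[100, 400],
    bucket_batch_sizes=[64, 64, 32]))

for batch in ds:
    print(batch.shape)   # padded [batch, max_len_in_bucket]

With the toy lengths above this prints one padded batch per non-empty bucket, since no bucket fills up before the input ends.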
Code example #4
File: read-records.py  Project: zjatc/wenzheng
def main(_):

  base = FLAGS.base
  logging.set_logging_path('./mount/tmp/')
  vocab_path = f'{base}/vocab.txt'
  ids2text.init(vocab_path)
  FLAGS.vocab = f'{base}/vocab.txt'

  # FLAGS.length_index = 2
  # FLAGS.buckets = '100,400'
  # FLAGS.batch_sizes = '64,64,32'

  input_ = FLAGS.input 
  if FLAGS.type == 'test':
    input_ = input_.replace('valid', 'test')

  inputs = gezi.list_files(input_)
  inputs.sort()
  if FLAGS.fold is not None:
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]

  if FLAGS.type == 'debug':
    print('type', FLAGS.type, 'inputs', inputs, file=sys.stderr)

    dataset = Dataset('valid')
    dataset = dataset.make_batch(FLAGS.batch_size_, inputs)

    print('dataset', dataset)

    timer = gezi.Timer('read record')
    for i, (x, y) in enumerate(dataset):
      # if i % 10 == 1:
      #   print(x['id'])
      #   print(x['content'][0])
      #   print(ids2text.ids2text(x['content'][0], sep='|'))
      #   print(x['content'])
      #   print(type(x['id'].numpy()[0]) == bytes)
      #   break
      x['id'] = gezi.decode(x['id'].numpy())
      x['content_str'] = gezi.decode(x['content_str'].numpy())
      for j, id in enumerate(x['id']):
        if id == '573':
          print(id, x['content_str'][j])
  elif FLAGS.type == 'dump':
    valid_infos = {}
    test_infos = {}
    # TODO: note that train and valid share ids, so keeping only the first (valid) entry per id is fine.
    # 120000 docs in total, but the first 15000 train ids duplicate valid ids, so currently only the valid result is saved for those ids.
    inputs = gezi.list_files(f'{base}/train/*record')
    dataset = Dataset('valid')
    dataset = dataset.make_batch(1, inputs)
    deal(dataset, valid_infos)
    print('after valid', len(valid_infos))

    for key in valid_infos:
      print(valid_infos[key])
      print(ids2text.ids2text(valid_infos[key]['content']))
      break

    ofile = f'{base}/info.pkl'
    with open(ofile, 'wb') as out:
      pickle.dump(valid_infos, out)  

    del valid_infos

    inputs = gezi.list_files(f'{base}/test/*record')
    dataset = Dataset('test')
    dataset = dataset.make_batch(1, inputs)
    deal(dataset, test_infos)
    print('after test', len(test_infos))

    ofile = ofile.replace('.pkl', '.test.pkl')  
    with open(ofile, 'wb') as out:
      pickle.dump(test_infos, out)
    for key in test_infos:
      print(test_infos[key])
      print(ids2text.ids2text(test_infos[key]['content']))
      break
  elif FLAGS.type == 'show_info':
    valid_infos = pickle.load(open(f'{base}/info.pkl', 'rb'))
    lens = [len(valid_infos[key]['content']) for key in valid_infos]
    unks = [list(valid_infos[key]['content']).count(FLAGS.unk_id) for key in valid_infos]
    print('avg unks per doc:', sum(unks) / len(valid_infos))
    print('ratio of docs with unks:', len([x for x in unks if x != 0]) / len(unks))
    print('unk token ratio:', sum(unks) / sum(lens))
    print('len max:', np.max(lens))
    print('len min:', np.min(lens))
    print('len mean:', np.mean(lens)) 
    print('num docs:', len(valid_infos))

    num_show = 0
    for key in valid_infos:
      if list(valid_infos[key]['content']).count(FLAGS.unk_id) > 0:
        print(valid_infos[key])
        print(ids2text.ids2text(valid_infos[key]['content']))
        num_show += 1
        if num_show > 5:
          break

    del valid_infos
    print('--------------for test info:')
    test_infos = pickle.load(open(f'{base}/info.test.pkl', 'rb'))
    lens = [len(test_infos[key]['content']) for key in test_infos]
    unks = [list(test_infos[key]['content']).count(FLAGS.unk_id) for key in test_infos]
    print('avg unks per doc:', sum(unks) / len(test_infos))
    print('ratio of docs with unks:', len([x for x in unks if x != 0]) / len(test_infos))
    print('unk token ratio:', sum(unks) / sum(lens))
    print('len max:', np.max(lens))
    print('len min:', np.min(lens))
    print('len mean:', np.mean(lens))
    print('num docs:', len(test_infos))

    num_show = 0
    for key in test_infos:
      if list(test_infos[key]['content']).count(FLAGS.unk_id) > 0:
        print(test_infos[key])
        print(ids2text.ids2text(test_infos[key]['content']))
        num_show += 1
        if num_show > 5:
          break
  else:
    raise ValueError(FLAGS.type)
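deal() is defined elsewhere in the project and is not shown in this snippet. Judging from how valid_infos and test_infos are consumed above (a dict keyed by example id, each entry carrying at least a 'content' field, with duplicated ids stored only once), a hedged reconstruction might look like this:

def deal(dataset, infos):
    # Hypothetical sketch: index each example's fields by its id (batch size is 1 here).
    for x, _ in dataset:
        id_ = x['id'].numpy()[0].decode()
        if id_ in infos:   # train and valid share ids; keep the first occurrence only
            continue
        infos[id_] = {
            'content': x['content'].numpy()[0],
            'content_str': x['content_str'].numpy()[0].decode(),
        }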