Esempio n. 1
0
def main(_):
    """Build tfrecords from FLAGS.input, which is a single file or a directory."""
    text2ids.init(FLAGS.vocab_)
    print('to_lower:', FLAGS.to_lower, 'feed_single_en:', FLAGS.feed_single_en,
          'seg_method', FLAGS.seg_method)
    # Sanity-check the tokenizer round trip on a couple of sample strings.
    print(text2ids.ids2text(text2ids_('傻逼脑残B')))
    print(text2ids_('傻逼脑残B'))
    print(text2ids.ids2text(text2ids_('喜欢玩孙尚香的加我好友:2948291976')))

    if os.path.isfile(FLAGS.input):
        # Single input file: process it in this process.
        build_features(FLAGS.input)
    else:
        # Directory input: fan out one worker per CPU over every file inside.
        input_files = glob.glob(FLAGS.input + '/*')
        worker_pool = multiprocessing.Pool(multiprocessing.cpu_count())
        worker_pool.map(build_features, input_files)
        worker_pool.close()
        worker_pool.join()

    # for safe some machine might not use cpu count as default ...
    print('num_records:', counter.value)
    mode = get_mode(FLAGS.input)

    os.system('mkdir -p %s/%s' % (os.path.dirname(FLAGS.vocab_), mode))
    count_path = os.path.dirname(
        FLAGS.vocab_) + '/{0}/num_records.txt'.format(mode)
    gezi.write_to_txt(counter.value, count_path)

    print('mean words:', total_words.value / counter.value)
def main(_):
    """Build sharded tfrecords from the csv file at FLAGS.input."""
    text2ids.init(FLAGS.vocab_)
    print('to_lower:', FLAGS.to_lower, 'feed_single_en:', FLAGS.feed_single_en,
          'seg_method', FLAGS.seg_method)
    # Sanity-check the tokenizer round trip on sample strings.
    print(text2ids.ids2text(text2ids_('傻逼脑残B')))
    print(text2ids.ids2text(text2ids_('喜欢玩孙尚香的加我好友:2948291976')))

    # The workers index into the shared dataframe by fold.
    global df
    df = pd.read_csv(FLAGS.input, lineterminator='\n')

    mode = get_mode(FLAGS.input)

    # Evaluation-style splits always go into a single record file.
    if mode in ['valid', 'test', 'dev', 'pm']:
        FLAGS.num_records_ = 1

    print('num records file to gen', FLAGS.num_records_)

    shard_pool = multiprocessing.Pool()
    shard_pool.map(build_features, range(FLAGS.num_records_))
    shard_pool.close()
    shard_pool.join()

    # for safe some machine might not use cpu count as default ...
    print('num_records:', counter.value)

    os.system('mkdir -p %s/%s' % (os.path.dirname(FLAGS.vocab_), mode))
    count_path = os.path.dirname(
        FLAGS.vocab_) + '/{0}/num_records.txt'.format(mode)
    gezi.write_to_txt(counter.value, count_path)

    print('mean words:', total_words.value / counter.value)
def seg(text, out):
    """Tokenize `text` and write the space-joined tokens to `out`; writes nothing when no tokens survive stripping."""
    raw_tokens = text2ids.ids2words(text2ids_(text))
    tokens = []
    for token in raw_tokens:
        token = token.strip()
        if token:
            tokens.append(token)
    if not tokens:
        return
    print(' '.join(tokens), file=out)
Esempio n. 4
0
def seg(id, text, out):
  """Filter `text`, tokenize it, and emit `id` plus tab-joined tokens to `out`."""
  cleaned = filter.filter(text)
  _, tokens = text2ids_(cleaned, return_words=True)
  print(id, '\x09'.join(tokens), sep='\t', file=out)
Esempio n. 5
0
def build_features(index):
    """Write tfrecord shard `index` (offset by FLAGS.start_index) from the global df.

    Converts each row's content to word ids (either via precomputed
    seg/pos/ner results or by tokenizing on the fly), builds a
    tf.train.Example per row, and updates the shared `counter` /
    `total_words` multiprocessing values.
    """
    mode = get_mode(FLAGS.input)

    start_index = FLAGS.start_index

    out_file = os.path.dirname(FLAGS.vocab_) + '/{0}/{1}.record'.format(
        mode, index + start_index)
    os.system('mkdir -p %s' % os.path.dirname(out_file))
    print('---out_file', out_file)
    # TODO now only gen one tfrecord file

    total = len(df)
    num_records = FLAGS.num_records_
    # Evaluation-style inputs are never sharded: everything goes to one file.
    if mode.split('.')[-1] in ['valid', 'test', 'dev', 'pm'
                               ] or 'valid' in FLAGS.input:
        num_records = 1
    start, end = gezi.get_fold(total, num_records, index)

    print('total', total, 'infile', FLAGS.input, 'out_file', out_file)

    max_len = 0
    max_num_ids = 0
    num = 0
    with melt.tfrecords.Writer(out_file) as writer:
        for i in tqdm(range(start, end), ascii=True):
            try:
                row = df.iloc[i]
                id = str(row[0])

                if seg_result:
                    if id not in seg_result:
                        print('id %s not found in seg_result' % id)
                        continue
                    words = seg_result[id]
                    if FLAGS.add_start_end_:
                        words = gezi.add_start_end(words, FLAGS.start_mark,
                                                   FLAGS.end_mark)
                # NOTE(review): pos/ner are left unbound when these results
                # are absent but pos_vocab/ner_vocab are set below — confirm
                # the configs always pair them.
                if pos_result:
                    pos = pos_result[id]
                    if FLAGS.add_start_end_:
                        pos = gezi.add_start_end(pos)
                if ner_result:
                    ner = ner_result[id]
                    if FLAGS.add_start_end_:
                        ner = gezi.add_start_end(ner)

                # BUG FIX: original wrote `id == 't' + id`, a no-op
                # comparison; the intent is to prefix the id.
                if start_index > 0:
                    id = 't' + id

                content = row[1]
                content_ori = content
                content = filter.filter(content)

                # Labels are placeholders here (no real labels in this csv).
                label = [-2] * 20

                if not seg_result:
                    content_ids, words = text2ids_(content,
                                                   preprocess=False,
                                                   return_words=True)
                    assert len(content_ids) == len(words)
                else:
                    content_ids = [vocab.id(x) for x in words]

                if len(content_ids) > max_len:
                    max_len = len(content_ids)
                    print('max_len', max_len)

                if len(content_ids) > FLAGS.word_limit and len(
                        content_ids) < 5:
                    print('{} {} {}'.format(id, len(content_ids), content_ori))

                content_ids = content_ids[:FLAGS.word_limit]
                words = words[:FLAGS.word_limit]

                # NOTICE different from tf, pytorch do not allow all 0 seq for rnn.. if using padding mode
                if FLAGS.use_char:
                    chars = [list(word) for word in words]
                    char_ids = np.zeros([len(content_ids), FLAGS.char_limit],
                                        dtype=np.int32)

                    vocab_ = char_vocab if char_vocab else vocab

                    # Distinct loop variables: the original reused `i`,
                    # shadowing the outer row index.
                    for wi, token in enumerate(chars):
                        for ci, ch in enumerate(token):
                            if ci == FLAGS.char_limit:
                                break
                            char_ids[wi, ci] = vocab_.id(ch)

                    char_ids = list(char_ids.reshape(-1))
                    if np.sum(char_ids) == 0:
                        print('------------------------bad id', id)
                        print(content_ids)
                        print(words)
                        exit(0)
                else:
                    char_ids = [0]

                if pos_vocab:
                    assert pos
                    pos = pos[:FLAGS.word_limit]
                    pos_ids = [pos_vocab.id(x) for x in pos]
                else:
                    pos_ids = [0]

                if ner_vocab:
                    assert ner
                    if pos_vocab:
                        assert len(pos) == len(ner)
                    ner = ner[:FLAGS.word_limit]

                    ner_ids = [ner_vocab.id(x) for x in ner]
                else:
                    ner_ids = [0]

                wlen = [len(word) for word in words]

                feature = {
                    'id': melt.bytes_feature(id),
                    'label': melt.int64_feature(label),
                    'content': melt.int64_feature(content_ids),
                    'content_str': melt.bytes_feature(content_ori),
                    'char': melt.int64_feature(char_ids),
                    'pos': melt.int64_feature(
                        pos_ids),  # might also be postion info for mix seg
                    'ner': melt.int64_feature(ner_ids),
                    'wlen': melt.int64_feature(wlen),
                    'source': melt.bytes_feature(mode),
                }

                # TODO currenlty not get exact info wether show 1 image or 3 ...
                record = tf.train.Example(features=tf.train.Features(
                    feature=feature))

                writer.write(record)
                num += 1
                global counter
                with counter.get_lock():
                    counter.value += 1
                global total_words
                with total_words.get_lock():
                    total_words.value += len(content_ids)
            except Exception:
                # Best effort: log the failing row and keep going.
                print(traceback.format_exc(), file=sys.stderr)
def build_features(index):
    """Write tfrecord shard `index` from the global df (csv already loaded).

    Over-long or degenerate rows are skipped for training modes; the shared
    `counter` / `total_words` multiprocessing values are updated per example.
    """
    mode = get_mode(FLAGS.input)

    # When folding is enabled the first shard slot is reserved.
    start_index = 0 if not FLAGS.use_fold else 1
    out_file = os.path.dirname(FLAGS.vocab) + '/{0}/{1}.record'.format(
        mode, index + start_index)
    os.system('mkdir -p %s' % os.path.dirname(out_file))
    print('---out_file', out_file)
    # TODO now only gen one tfrecord file

    total = len(df)
    num_records = FLAGS.num_records_
    # Evaluation-style splits always produce a single record file.
    if mode in ['valid', 'test', 'dev', 'pm']:
        num_records = 1
    start, end = gezi.get_fold(total, num_records, index)

    print('infile', FLAGS.input, 'out_file', out_file)

    max_len = 0
    max_num_ids = 0
    num = 0
    with melt.tfrecords.Writer(out_file) as writer:
        for i in range(start, end):
            try:
                row = df.iloc[i]
                id = row[0]
                content = row[1]

                if len(content) > max_len:
                    max_len = len(content)
                    print('max_len', max_len)

                # Very long contents are only kept for eval splits.
                if len(content) > 3000:
                    print(id, content)
                    if mode not in ['test', 'valid']:
                        continue

                label = list(row[2:])

                content_ids = text2ids_(content)

                # Drop degenerate (too short) examples from training data.
                if len(content_ids) < 5 and mode not in ['test', 'valid']:
                    continue

                limit = FLAGS.limit
                if len(content_ids) > max_num_ids:
                    max_num_ids = len(content_ids)
                    print('max_num_ids', max_num_ids)
                content_ids = content_ids[:limit]

                feature = {
                    'id': melt.bytes_feature(str(id)),
                    'label': melt.int64_feature(label),
                    'content': melt.int64_feature(content_ids),
                    'content_str': melt.bytes_feature(content),
                    # BUG FIX: key was misspelled 'sorce'; renamed to
                    # 'source' to match the sibling build_features variant.
                    'source': melt.bytes_feature(mode),
                }

                # TODO currenlty not get exact info wether show 1 image or 3 ...
                record = tf.train.Example(features=tf.train.Features(
                    feature=feature))

                if num % 1000 == 0:
                    print(num)

                writer.write(record)
                num += 1
                global counter
                with counter.get_lock():
                    counter.value += 1
                global total_words
                with total_words.get_lock():
                    total_words.value += len(content_ids)
            except Exception:
                # FIX: the original silently swallowed every error (the log
                # line was commented out); keep best-effort but log it.
                print(traceback.format_exc(), file=sys.stderr)
Esempio n. 7
0
def build_features(file_):
    """Convert one json-lines MRC file into a tfrecord of query/passage examples.

    Each input line carries a query, passage, alternatives and (optionally)
    an answer; alternatives are sorted into (neg, pos, uncertain) candidates
    by sort_alternatives, and the answer index within that ordering is stored.
    """
    mode = get_mode(FLAGS.input)
    out_file = os.path.dirname(FLAGS.vocab_) + '/{0}/{1}_{2}.tfrecord'.format(
        mode, os.path.basename(os.path.dirname(file_)),
        os.path.basename(file_))
    os.system('mkdir -p %s' % os.path.dirname(out_file))
    print('infile', file_, 'out_file', out_file)

    num = 0
    num_whether = 0
    answer_len = 0
    # Pre-bind so the except handler below never raises a masking NameError
    # when the very first statement of the try block fails.
    query = None
    alternatives = None
    # FIX: open the input via a context manager — the original leaked the
    # handle from a bare open().
    with melt.tfrecords.Writer(out_file) as writer, open(file_) as infile:
        for line in infile:
            try:
                m = json.loads(line.rstrip('\n'))
                url = m['url']
                alternatives = m['alternatives']
                query_id = int(m['query_id'])
                passage = m['passage']
                query = m['query']

                if not 'answer' in m:
                    answer = 'unknown'
                else:
                    answer = m['answer']

                # candidates is neg,pos,uncertain
                # type 0 means true or false,  type 1 means wehter
                candidates, type = sort_alternatives(alternatives, query)

                assert candidates is not None

                # Index of the gold answer within the sorted candidates
                # (0 when the answer is unknown/not among them).
                answer_id = 0
                for i, candiate in enumerate(candidates):
                    if candiate == answer:
                        answer_id = i

                candidates_str = '|'.join(candidates)

                query_ids = text2ids_(query)
                passage_ids = text2ids_(passage)

                candidate_neg_ids = text2ids_(candidates[0])
                candidate_pos_ids = text2ids_(candidates[1])
                candidate_na_ids = text2ids_('无法确定')

                # Track the longest candidate seen (diagnostic only).
                if len(candidate_pos_ids) > answer_len:
                    answer_len = len(candidate_pos_ids)
                    print(answer_len)
                if len(candidate_neg_ids) > answer_len:
                    answer_len = len(candidate_neg_ids)
                    print(answer_len)

                assert len(query_ids), line
                assert len(passage_ids), line

                limit = FLAGS.limit

                if len(passage_ids) > limit:
                    print('long line', len(passage_ids), query_id)

                query_ids = query_ids[:limit]
                passage_ids = passage_ids[:limit]

                feature = {
                    'id': melt.bytes_feature(str(query_id)),
                    'url': melt.bytes_feature(url),
                    'alternatives': melt.bytes_feature(alternatives),
                    'candidates': melt.bytes_feature(candidates_str),
                    'passage': melt.int64_feature(passage_ids),
                    'passage_str': melt.bytes_feature(passage),
                    'query': melt.int64_feature(query_ids),
                    'query_str': melt.bytes_feature(query),
                    'candidate_neg': melt.int64_feature(candidate_neg_ids),
                    'candidate_pos': melt.int64_feature(candidate_pos_ids),
                    'candidate_na': melt.int64_feature(candidate_na_ids),
                    'answer': melt.int64_feature(answer_id),
                    'answer_str': melt.bytes_feature(answer),
                    'type': melt.int64_feature(type)
                }

                # TODO currenlty not get exact info wether show 1 image or 3 ...
                record = tf.train.Example(features=tf.train.Features(
                    feature=feature))

                if num % 1000 == 0:
                    print(num, query_id, query, type)
                    print(alternatives, candidates)
                    print(answer, answer_id)

                writer.write(record)
                num += 1
                if type:
                    num_whether += 1
                global counter
                with counter.get_lock():
                    counter.value += 1
                global total_words
                with total_words.get_lock():
                    total_words.value += len(passage_ids)
                if FLAGS.max_examples and num >= FLAGS.max_examples:
                    break
            except Exception:
                print(traceback.format_exc(), file=sys.stderr)
                print('-----------', query)
                print(alternatives)

    print('num_wehter:', num_whether)