def build_features(file_):
  """Convert one tokenized text file into a TFRecord file of Example protos.

  Each line of `file_` is filtered, split on spaces, wrapped with start/end
  marks, broken into sentences of at most FLAGS.max_sentence_len words, and
  written as one Example per sentence (word ids, raw text, optional char
  ids, source tag). Per-record progress is accumulated into the
  process-shared `counter` / `total_words` values.
  """
  if not os.path.isfile(file_):
    return 

  file_name = os.path.basename(file_)
  assert os.path.isdir(FLAGS.input)
  # Output mode mirrors the input directory name ('train' vs everything else).
  mode = 'train' if 'train' in FLAGS.input else 'valid'
  dir_ = os.path.dirname(os.path.dirname(FLAGS.input))
  out_file = os.path.join(dir_ , '{}/{}/{}.record'.format(FLAGS.tfrecord_dir, mode, file_name))
  os.system('mkdir -p %s' % os.path.dirname(out_file))
  
  print('infile', file_, 'out_file', out_file)

  max_len = 0
  max_num_ids = 0
  num = 0
  with melt.tfrecords.Writer(out_file) as writer:
    # total=1e6 is only a progress-bar estimate, not a hard limit.
    for line in tqdm(open(file_), total=1e6, ascii=True):
      try:
        line = line.rstrip('\n')
        line = filter.filter(line)
        words = line.split(' ')
        words = gezi.add_start_end(words)
        words_list = gezi.break_sentence(words, FLAGS.max_sentence_len)
        for words in words_list:
          content = ' '.join(words)
          content_ids = [vocab.id(x) for x in words]

          if len(content_ids) > max_len:
            max_len = len(content_ids)
            print('max_len', max_len)

          # NOTE(review): only reachable when FLAGS.word_limit < 5
          # (len > word_limit AND len < 5); presumably a debug print.
          if len(content_ids) > FLAGS.word_limit and len(content_ids) < 5:
            # BUGFIX: the original formatted `content_ori` (undefined in this
            # function -> NameError, silently swallowed by the broad except
            # below, which dropped the record) and the builtin `id`. Use the
            # names actually in scope.
            print('{} {} {}'.format(file_name, len(content_ids), content))

          content_ids = content_ids[:FLAGS.word_limit]
          words = words[:FLAGS.word_limit]

          # NOTICE different from tf, pytorch does not allow an all-0 seq
          # for rnn if using padding mode.
          if FLAGS.use_char:
            chars = [list(word) for word in words]
            char_ids = np.zeros([len(content_ids), FLAGS.char_limit], dtype=np.int32)
            
            # Fall back to the word vocab when no char vocab is provided.
            vocab_ = char_vocab if char_vocab else vocab

            for i, token in enumerate(chars):
              for j, ch in enumerate(token):
                if j == FLAGS.char_limit:
                  break
                char_ids[i, j] = vocab_.id(ch)

            char_ids = list(char_ids.reshape(-1))
          else:
            char_ids = [0]

          feature = {
                      'content':  melt.int64_feature(content_ids),
                      'content_str': melt.bytes_feature(content), 
                      'char': melt.int64_feature(char_ids),
                      'source': melt.bytes_feature(FLAGS.source), 
                    }

          # TODO currenlty not get exact info wether show 1 image or 3 ...
          record = tf.train.Example(features=tf.train.Features(feature=feature))

          writer.write(record)
          num += 1
          # Process-shared counters: guard increments with their locks.
          global counter
          with counter.get_lock():
            counter.value += 1
          global total_words
          with total_words.get_lock():
            total_words.value += len(content_ids)
      except Exception:
        # Best-effort conversion: log the failing line and keep going.
        print(traceback.format_exc(), file=sys.stderr)
# Ejemplo n.º 2
# 0
def build_features(index):
    """Write shard `index` of the global dataframe `df` as a TFRecord file.

    The rows [start, end) of the fold assigned to this shard are converted
    to Example protos (word ids, optional char ids, label, raw content) and
    written to <dir(FLAGS.vocab_)>/<mode>/<index + start_index>.record.
    Per-record progress is accumulated into the process-shared
    `counter` / `total_words` values.
    """
    mode = get_mode(FLAGS.input)

    start_index = FLAGS.start_index

    out_file = os.path.dirname(FLAGS.vocab_) + '/{0}/{1}.record'.format(
        mode, index + start_index)
    os.system('mkdir -p %s' % os.path.dirname(out_file))
    print('---out_file', out_file)
    # TODO now only gen one tfrecord file

    total = len(df)
    num_records = FLAGS.num_records_
    ## TODO FIXME why here still None ? FLAGS.num_records has been modified before in main as 7 ...
    # BUGFIX/cleanup: the original if/else assigned num_records = 1 on BOTH
    # branches; the dead branching is collapsed (behavior unchanged).
    if not num_records:
        num_records = 1
    start, end = gezi.get_fold(total, num_records, index)

    print('total', total, 'infile', FLAGS.input, 'out_file', out_file,
          'num_records', num_records, 'start', start, 'end', end)

    max_len = 0
    max_num_ids = 0
    num = 0
    with melt.tfrecords.Writer(out_file) as writer:
        for i in tqdm(range(start, end), ascii=True):
            try:
                row = df[i]
                # Renamed from `id` to avoid shadowing the builtin.
                id_ = str(row[0])

                words = row[-1].split('\t')

                content = row[2]
                content_ori = content
                content = filter.filter(content)

                label = int(row[1])

                content_ids = [vocab.id(x) for x in words]

                if len(content_ids) > max_len:
                    max_len = len(content_ids)
                    print('max_len', max_len)

                # NOTE(review): only reachable when FLAGS.word_limit < 5
                # (len > word_limit AND len < 5); presumably a debug print.
                if len(content_ids) > FLAGS.word_limit and len(
                        content_ids) < 5:
                    print('{} {} {}'.format(id_, len(content_ids), content_ori))

                content_ids = content_ids[:FLAGS.word_limit]
                words = words[:FLAGS.word_limit]

                # NOTICE different from tf, pytorch does not allow an all-0
                # seq for rnn if using padding mode.
                if FLAGS.use_char:
                    chars = [list(word) for word in words]
                    char_ids = np.zeros([len(content_ids), FLAGS.char_limit],
                                        dtype=np.int32)

                    # Fall back to the word vocab when no char vocab exists.
                    vocab_ = char_vocab if char_vocab else vocab

                    # `wi` renamed from `i` so the outer row index is not
                    # shadowed inside the loop body.
                    for wi, token in enumerate(chars):
                        for j, ch in enumerate(token):
                            if j == FLAGS.char_limit:
                                break
                            char_ids[wi, j] = vocab_.id(ch)

                    char_ids = list(char_ids.reshape(-1))
                    if np.sum(char_ids) == 0:
                        print('------------------------bad id', id_)
                        print(content_ids)
                        print(words)
                        exit(0)
                else:
                    char_ids = [0]

                feature = {
                    'id': melt.bytes_feature(id_),
                    'content': melt.int64_feature(content_ids),
                    'content_str': melt.bytes_feature(content_ori),
                    'char': melt.int64_feature(char_ids),
                    'source': melt.bytes_feature(mode),
                }
                feature['label'] = melt.int64_feature(label)

                # TODO currenlty not get exact info wether show 1 image or 3 ...
                record = tf.train.Example(features=tf.train.Features(
                    feature=feature))

                writer.write(record)
                num += 1
                # Process-shared counters: guard increments with their locks.
                global counter
                with counter.get_lock():
                    counter.value += 1
                global total_words
                with total_words.get_lock():
                    total_words.value += len(content_ids)
            except Exception:
                # Best-effort conversion: log the failing row and continue.
                print(traceback.format_exc(), file=sys.stderr)
# Ejemplo n.º 3
# 0
def build_features(index):
    """Write shard `index` of the global dataframe `df` as a TFRecord file.

    Uses optional pre-computed lookups keyed by row id — segmentation
    (`seg_result`), POS (`pos_result`) and NER (`ner_result`) — and falls
    back to `text2ids_` for on-the-fly segmentation. Each row becomes an
    Example proto with word ids, optional char/pos/ner ids, per-word
    lengths, a placeholder label vector and the raw content, written to
    <dir(FLAGS.vocab_)>/<mode>/<index + start_index>.record. Per-record
    progress is accumulated into the process-shared `counter` /
    `total_words` values.
    """
    mode = get_mode(FLAGS.input)

    start_index = FLAGS.start_index

    out_file = os.path.dirname(FLAGS.vocab_) + '/{0}/{1}.record'.format(
        mode, index + start_index)
    os.system('mkdir -p %s' % os.path.dirname(out_file))
    print('---out_file', out_file)
    # TODO now only gen one tfrecord file

    total = len(df)
    num_records = FLAGS.num_records_
    # Evaluation-style inputs are always written as a single record file.
    if mode.split('.')[-1] in ['valid', 'test', 'dev', 'pm'
                               ] or 'valid' in FLAGS.input:
        num_records = 1
    start, end = gezi.get_fold(total, num_records, index)

    print('total', total, 'infile', FLAGS.input, 'out_file', out_file)

    max_len = 0
    max_num_ids = 0
    num = 0
    with melt.tfrecords.Writer(out_file) as writer:
        for i in tqdm(range(start, end), ascii=True):
            try:
                row = df.iloc[i]
                # Renamed from `id` to avoid shadowing the builtin.
                id_ = str(row[0])

                if seg_result:
                    if id_ not in seg_result:
                        # BUGFIX: log-message typo "ot found" -> "not found".
                        print('id %s not found in seg_result' % id_)
                        continue
                    words = seg_result[id_]
                    if FLAGS.add_start_end_:
                        words = gezi.add_start_end(words, FLAGS.start_mark,
                                                   FLAGS.end_mark)
                # NOTE(review): `pos`/`ner` stay unbound when these lookups
                # are empty; the asserts further down then rely on
                # pos_vocab/ner_vocab also being falsy — confirm invariant.
                if pos_result:
                    pos = pos_result[id_]
                    if FLAGS.add_start_end_:
                        pos = gezi.add_start_end(pos)
                if ner_result:
                    ner = ner_result[id_]
                    if FLAGS.add_start_end_:
                        ner = gezi.add_start_end(ner)

                # BUGFIX: the original wrote `id == 't' + id` — a no-op
                # comparison whose result was discarded. The intent is to
                # prefix ids when writing a re-started/offset shard.
                if start_index > 0:
                    id_ = 't' + id_

                content = row[1]
                content_ori = content
                content = filter.filter(content)

                # Placeholder labels (20 attributes, -2 sentinel); the real
                # labels in row[2:] are intentionally not used here.
                label = [-2] * 20

                if not seg_result:
                    content_ids, words = text2ids_(content,
                                                   preprocess=False,
                                                   return_words=True)
                    assert len(content_ids) == len(words)
                else:
                    content_ids = [vocab.id(x) for x in words]

                if len(content_ids) > max_len:
                    max_len = len(content_ids)
                    print('max_len', max_len)

                # NOTE(review): only reachable when FLAGS.word_limit < 5
                # (len > word_limit AND len < 5); presumably a debug print.
                if len(content_ids) > FLAGS.word_limit and len(
                        content_ids) < 5:
                    print('{} {} {}'.format(id_, len(content_ids), content_ori))

                content_ids = content_ids[:FLAGS.word_limit]
                words = words[:FLAGS.word_limit]

                # NOTICE different from tf, pytorch does not allow an all-0
                # seq for rnn if using padding mode.
                if FLAGS.use_char:
                    chars = [list(word) for word in words]
                    char_ids = np.zeros([len(content_ids), FLAGS.char_limit],
                                        dtype=np.int32)

                    # Fall back to the word vocab when no char vocab exists.
                    vocab_ = char_vocab if char_vocab else vocab

                    # `wi` renamed from `i` so the outer row index is not
                    # shadowed inside the loop body.
                    for wi, token in enumerate(chars):
                        for j, ch in enumerate(token):
                            if j == FLAGS.char_limit:
                                break
                            char_ids[wi, j] = vocab_.id(ch)

                    char_ids = list(char_ids.reshape(-1))
                    if np.sum(char_ids) == 0:
                        print('------------------------bad id', id_)
                        print(content_ids)
                        print(words)
                        exit(0)
                else:
                    char_ids = [0]

                if pos_vocab:
                    assert pos
                    pos = pos[:FLAGS.word_limit]
                    pos_ids = [pos_vocab.id(x) for x in pos]
                else:
                    pos_ids = [0]

                if ner_vocab:
                    assert ner
                    if pos_vocab:
                        assert len(pos) == len(ner)
                    ner = ner[:FLAGS.word_limit]

                    ner_ids = [ner_vocab.id(x) for x in ner]
                else:
                    ner_ids = [0]

                # Per-word character lengths.
                wlen = [len(word) for word in words]

                feature = {
                    'id': melt.bytes_feature(id_),
                    'label': melt.int64_feature(label),
                    'content': melt.int64_feature(content_ids),
                    'content_str': melt.bytes_feature(content_ori),
                    'char': melt.int64_feature(char_ids),
                    'pos': melt.int64_feature(
                        pos_ids),  # might also be postion info for mix seg
                    'ner': melt.int64_feature(ner_ids),
                    'wlen': melt.int64_feature(wlen),
                    'source': melt.bytes_feature(mode),
                }

                # TODO currenlty not get exact info wether show 1 image or 3 ...
                record = tf.train.Example(features=tf.train.Features(
                    feature=feature))

                writer.write(record)
                num += 1
                # Process-shared counters: guard increments with their locks.
                global counter
                with counter.get_lock():
                    counter.value += 1
                global total_words
                with total_words.get_lock():
                    total_words.value += len(content_ids)
            except Exception:
                # Best-effort conversion: log the failing row and continue.
                print(traceback.format_exc(), file=sys.stderr)