Example #1
  def test_get_tokenizer(self):
    tk = util.get_tokenizer(self.vocab)
    # The reserved IDs decode back to their special-token strings.
    self.assertEqual('<pad><EOS>', tk.decode([util.PAD_ID, util.EOS_ID]))
    self.assertEqual(32583, tk.vocab_size)
    text = 'Tokenize this!'
    enc = tk.encode(text)
    self.assertEqual([15745, 8579, 2131, 61, 32582, 11], enc)
    # Encoding is invertible: decoding the IDs recovers the original text.
    self.assertEqual(text, tk.decode(enc))
Example #2
  def test_get_tokenizer_with_special(self):
    tk_original = util.get_tokenizer(self.vocab)
    extra_tokens = ['<SPECIAL1>', '<SPECIAL2>']
    tk_with_special, sids = util.get_tokenizer_with_special(
        self.vocab, extra_tokens)
    # The extra tokens are appended after the original vocabulary, so the
    # vocabulary grows by exactly len(extra_tokens).
    o_size = tk_original.vocab_size
    self.assertEqual(o_size + 2, tk_with_special.vocab_size)
    self.assertEqual(['<SPECIAL1>_', '<SPECIAL2>_'],
                     tk_with_special.decode_list([o_size, o_size + 1]))
    # `sids` maps each special-token string to its newly assigned ID.
    self.assertEqual(o_size, sids['<SPECIAL1>'])
    self.assertEqual(o_size + 1, sids['<SPECIAL2>'])
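For context, here is a minimal usage sketch of the API exercised by the test above. It is not part of the original code: the vocabulary path is a placeholder, and it assumes the tokenizer returned by util.get_tokenizer_with_special exposes the same encode/decode interface as util.get_tokenizer.

  # Hypothetical usage sketch; vocab_path is a placeholder.
  vocab_path = '/path/to/vocab'
  tk, sids = util.get_tokenizer_with_special(vocab_path,
                                             ['<SPECIAL1>', '<SPECIAL2>'])
  # Encode text and append a special-token ID as an explicit marker.
  ids = tk.encode('Tokenize this!') + [sids['<SPECIAL1>']]
  print(tk.decode(ids))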
Example #3
# Imports assumed by this example. `util` and `data_util` are the project's
# own helper modules, and the flags used below (raw_dir, vocab_file,
# output_base, num_shards) are defined elsewhere in the script.
import csv
import os

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf  # The TF1-style tf.gfile / tf.train APIs are used below.

import data_util
import util

FLAGS = flags.FLAGS


def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    paths = [
        os.path.join(FLAGS.raw_dir, x) for x in [
            'ROCStories__spring2016 - ROCStories_spring2016.csv',
            'ROCStories_winter2017 - ROCStories_winter2017.csv'
        ]
    ]

    assert paths, FLAGS.raw_dir
    logging.info('Reading from: %s', paths)
    logging.info('Loading vocabulary file from %s', FLAGS.vocab_file)
    tk = util.get_tokenizer(FLAGS.vocab_file)
    assert tk

    writers = data_util.get_filewriters(FLAGS.output_base, 'all',
                                        FLAGS.num_shards)
    sharder = data_util.get_text_sharder(FLAGS.num_shards)
    count = 0

    for p in paths:
        logging.info('Opening %s', p)
        with tf.gfile.Open(p) as f:
            reader = csv.reader(f)
            next(reader)  # Skip the CSV header row.
            for r in reader:
                # Each row: storyid, storytitle, and the five story sentences.
                assert len(r) == 7
                storyid = r[0]
                storytitle = r[1]
                sentences = r[2:7]
                context_features = tf.train.Features(
                    feature={
                        'storyid': data_util.tf_bytes_feature(storyid),
                        'storytitle': data_util.tf_bytes_feature(storytitle),
                    })
                seq_ex = data_util.sents2seqex(
                    sentences,
                    tk,
                    context_features=context_features,
                    add_eos=False,
                    add_untokenized=True)
                # Route each story to a shard keyed by its story ID.
                writers[sharder(storyid)].write(seq_ex.SerializeToString())
                count += 1
    data_util.close_writers(writers)
    logging.info('Wrote %d records to %d shards.', count, FLAGS.num_shards)
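To sanity-check the output, here is a small inspection sketch. It is not part of the original script: it assumes data_util.get_filewriters produces sharded TFRecord files, and the glob pattern passed to inspect_shards is a placeholder that must be adapted to whatever filenames the writers actually use. Only the context features written above (storyid, storytitle) are read back; the sequence feature names produced by data_util.sents2seqex are not shown in the original code, so they are left alone here.

import tensorflow as tf


def inspect_shards(pattern, limit=3):
    """Prints the context features of the first few records in each shard."""
    for path in tf.gfile.Glob(pattern):
        print('Shard:', path)
        for i, record in enumerate(tf.python_io.tf_record_iterator(path)):
            seq_ex = tf.train.SequenceExample.FromString(record)
            ctx = seq_ex.context.feature
            print('  storyid:   ', ctx['storyid'].bytes_list.value[0].decode('utf-8'))
            print('  storytitle:', ctx['storytitle'].bytes_list.value[0].decode('utf-8'))
            if i + 1 >= limit:
                break


# Hypothetical invocation; the glob is a placeholder:
# inspect_shards('/tmp/rocstories_output/all*')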