def test_get_tokenizer(self):
  """Checks special-id decoding, vocab size, and an encode/decode round trip."""
  tokenizer = util.get_tokenizer(self.vocab)
  # The reserved PAD and EOS ids decode to their literal marker strings.
  self.assertEqual('<pad><EOS>',
                   tokenizer.decode([util.PAD_ID, util.EOS_ID]))
  self.assertEqual(32583, tokenizer.vocab_size)

  sample = 'Tokenize this!'
  expected_ids = [15745, 8579, 2131, 61, 32582, 11]
  ids = tokenizer.encode(sample)
  self.assertEqual(expected_ids, ids)
  # Decoding the encoded ids must reproduce the input text exactly.
  self.assertEqual(sample, tokenizer.decode(ids))
def test_get_tokenizer_with_special(self):
  """Extra tokens get consecutive ids right after the base vocabulary."""
  base_tokenizer = util.get_tokenizer(self.vocab)
  specials = ['<SPECIAL1>', '<SPECIAL2>']
  special_tokenizer, special_ids = util.get_tokenizer_with_special(
      self.vocab, specials)
  base_size = base_tokenizer.vocab_size

  self.assertEqual(base_size + len(specials), special_tokenizer.vocab_size)

  # The appended ids, in order, decode to the special tokens, each with the
  # trailing '_' that decode_list emits for them.
  appended = [base_size + offset for offset in range(len(specials))]
  self.assertEqual(['<SPECIAL1>_', '<SPECIAL2>_'],
                   special_tokenizer.decode_list(appended))

  # The returned mapping agrees with those ids.
  for offset, token in enumerate(specials):
    self.assertEqual(base_size + offset, special_ids[token])
def _read_story_rows(path):
  """Yields the data rows of one ROCStories CSV file, skipping its header.

  Args:
    path: Path of a CSV file readable through `tf.gfile`.

  Yields:
    Lists of exactly 7 string fields. Index 0 is the story id, index 1 the
    title, and indices 2-6 the sentences.

  Raises:
    ValueError: If the file has no header row, or a data row does not have
      exactly 7 columns. These are explicit raises rather than `assert`,
      which `python -O` strips.
  """
  logging.info('Opening %s', path)
  with tf.gfile.Open(path) as f:
    reader = csv.reader(f)
    # Skip the header. A bare next() would surface an opaque StopIteration
    # on an empty file.
    if next(reader, None) is None:
      raise ValueError('%s is empty; expected a CSV header row.' % path)
    for row in reader:
      if len(row) != 7:
        raise ValueError('%s:%d: expected 7 columns, got %d.' %
                         (path, reader.line_num, len(row)))
      yield row


def main(argv):
  """Converts the ROCStories CSV files into sharded SequenceExample records.

  Reads the two ROCStories CSV files from --raw_dir. Each row's five
  sentences are turned into a SequenceExample via `data_util.sents2seqex`,
  using the tokenizer loaded from --vocab_file, with the story id and title
  attached as context features. Each story is written as one serialized
  record into one of --num_shards writers under --output_base; the shard is
  chosen from the story id.

  Args:
    argv: Positional command-line arguments remaining after flag parsing.
      Only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If an input file is missing, empty or malformed, or the
      vocabulary could not be loaded.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  paths = [
      os.path.join(FLAGS.raw_dir, x) for x in [
          'ROCStories__spring2016 - ROCStories_spring2016.csv',
          'ROCStories_winter2017 - ROCStories_winter2017.csv'
      ]
  ]
  # `paths` is hard-coded and never empty, so asserting on it checked
  # nothing. Check the files themselves instead, before any output shard is
  # created.
  missing = [p for p in paths if not tf.gfile.Exists(p)]
  if missing:
    raise ValueError('Input file(s) not found under --raw_dir=%s: %s' %
                     (FLAGS.raw_dir, missing))
  logging.info('Reading from: %s', paths)

  logging.info('Loading vocabulary file from %s', FLAGS.vocab_file)
  tk = util.get_tokenizer(FLAGS.vocab_file)
  if not tk:
    raise ValueError('Could not load vocabulary from %s' % FLAGS.vocab_file)

  writers = data_util.get_filewriters(FLAGS.output_base, 'all',
                                      FLAGS.num_shards)
  sharder = data_util.get_text_sharder(FLAGS.num_shards)
  count = 0
  # Always close the writers, even when a malformed row aborts the run.
  try:
    for p in paths:
      for r in _read_story_rows(p):
        storyid, storytitle, sentences = r[0], r[1], r[2:7]
        context_features = tf.train.Features(
            feature={
                'storyid': data_util.tf_bytes_feature(storyid),
                'storytitle': data_util.tf_bytes_feature(storytitle),
            })
        seq_ex = data_util.sents2seqex(
            sentences,
            tk,
            context_features=context_features,
            add_eos=False,
            add_untokenized=True)
        writers[sharder(storyid)].write(seq_ex.SerializeToString())
        count += 1
  finally:
    data_util.close_writers(writers)
  logging.info('Wrote %d records to %d shards.', count, FLAGS.num_shards)