def __init__(self, params):
  """Builds the tokenizer and checks vocab IDs against the params.

  Args:
    params: Tokenizer hyperparams; must carry `vocab_filepath` plus the
      `target_unk_id` / `target_sos_id` / `target_eos_id` expected below.
  """
  super(WpmTokenizer, self).__init__(params)
  self._wpm_encoder = wpm_encoder.WpmEncoder(params.vocab_filepath)
  # Sanity-check that the configured special-token IDs agree with the
  # ones baked into the vocabulary the encoder loaded.
  p = self.params
  expected = [
      (p.target_unk_id, self._wpm_encoder.unk_id),
      (p.target_sos_id, self._wpm_encoder.sentence_start_id),
      (p.target_eos_id, self._wpm_encoder.sentence_end_id),
  ]
  for configured, from_vocab in expected:
    assert configured == from_vocab
def __init__(self, params):
  """Constructs the WPM encoder and validates special-token IDs.

  Args:
    params: Tokenizer hyperparams providing `vocab_filepath`, `merge_prob`,
      and the target unk/sos/eos IDs checked below.
  """
  super().__init__(params)
  p = self.params
  self._wpm_encoder = wpm_encoder.WpmEncoder(p.vocab_filepath, p.merge_prob)
  enc = self._wpm_encoder
  # The configured special-token IDs must match those of the loaded vocab.
  assert p.target_unk_id == enc.unk_id
  assert p.target_sos_id == enc.sentence_start_id
  assert p.target_eos_id == enc.sentence_end_id
def _RunEncoding():
  """Encodes source/target text files into a sharded TFRecord file.

  Reads the comma-separated file lists in FLAGS.source_filepaths and
  FLAGS.target_filepaths pairwise, preprocesses each line pair once, keeps
  only lines belonging to this shard (FLAGS.shard_id of FLAGS.num_shards),
  and writes the resulting tf.Examples to FLAGS.output_filepath.
  """
  enc = wpm_encoder.WpmEncoder(FLAGS.wpm_filepath)
  pairs = zip(FLAGS.source_filepaths.split(','),
              FLAGS.target_filepaths.split(','))
  with tf.python_io.TFRecordWriter(FLAGS.output_filepath) as outf:
    n = 0
    for p in pairs:
      with tf.gfile.Open(p[0], 'r') as sourcef:
        with tf.gfile.Open(p[1], 'r') as targetf:
          for textp in zip(sourcef.readlines(), targetf.readlines()):
            n += 1
            if n % 10000 == 0:
              tf.logging.info('Watermark[%d]: %d', FLAGS.shard_id, n)
            # Shard by global line index so concurrent workers partition
            # the corpus without coordination.
            if n % FLAGS.num_shards != FLAGS.shard_id:
              continue
            source_text = _Prepropcess(textp[0])
            target_text = _Prepropcess(textp[1])
            # BUG FIX: the original called _Prepropcess a second time on the
            # already-preprocessed text; preprocess exactly once per line.
            ex = _MakeTfExample(enc, source_text, target_text)
            if not ex:  # Too long.
              continue
            outf.write(ex.SerializeToString())
def testMergeProb(self):
  """With merge_prob=0 every word is split into single characters."""
  voc = self._CreateVocab()
  enc = wpm_encoder.WpmEncoder(voc, merge_prob=0.)
  with tf.Session():
    ids, strs = enc.Encode('Ditto')
    self.assertEqual(u'▁ D i t t o'.encode('utf-8'),
                     tf.strings.reduce_join(strs, separator=' ').eval())
    # CONSISTENCY FIX: decode with the same encoder that produced `ids`
    # (the original used self._enc; decoding only depends on the shared
    # vocab, but round-tripping through one encoder is less surprising).
    self.assertEqual(b'Ditto', enc.Decode(ids).eval())
def _RunEncoding():
  """Encodes paired source/target text files into a TFRecord shard.

  Builds a pair of Encode() graphs fed by string placeholders, walks the
  comma-separated source/target file lists line by line, keeps only the
  lines assigned to this shard, and serializes one tf.Example per pair to
  FLAGS.output_filepath.
  """
  sess = tf.Session()
  enc = wpm_encoder.WpmEncoder(FLAGS.wpm_filepath)
  # One placeholder + Encode op per side, built once and reused per line.
  src_txt_placeholder = tf.placeholder(tf.string, [])
  src_encode_op = enc.Encode(src_txt_placeholder)
  tgt_txt_placeholder = tf.placeholder(tf.string, [])
  tgt_encode_op = enc.Encode(tgt_txt_placeholder)
  source_files = FLAGS.source_filepaths.split(',')
  target_files = FLAGS.target_filepaths.split(',')
  with tf.python_io.TFRecordWriter(FLAGS.output_filepath) as outf:
    n = 0
    for src_path, tgt_path in zip(source_files, target_files):
      with tf.gfile.Open(src_path, 'r') as sourcef:
        with tf.gfile.Open(tgt_path, 'r') as targetf:
          for src_line, tgt_line in zip(sourcef.readlines(),
                                        targetf.readlines()):
            n += 1
            if n % 10000 == 0:
              tf.logging.info('Watermark[%d]: %d', FLAGS.shard_id, n)
            # Partition the corpus across shards by global line index.
            if n % FLAGS.num_shards != FLAGS.shard_id:
              continue
            source_text = _Preprocess(src_line)
            target_text = _Preprocess(tgt_line)
            # By convention:
            # * source always ends in </s>, never starts with <s>.
            # * target never ends in </s>, always starts with <s>.
            _AssertTextFormat(source_text)
            _AssertTextFormat(target_text)
            ((src_i, src_s), (tgt_i, tgt_s)) = sess.run(
                [src_encode_op, tgt_encode_op],
                feed_dict={
                    src_txt_placeholder: source_text,
                    tgt_txt_placeholder: target_text
                },
            )
            ex = _MakeTfExample(enc, src_i, src_s, tgt_i, tgt_s)
            if not ex:  # Too long.
              continue
            outf.write(ex.SerializeToString())
def setUp(self):
  """Creates a fresh vocab and a default WpmEncoder for each test."""
  self._enc = wpm_encoder.WpmEncoder(self._CreateVocab())
def testMergeProb(self):
  """merge_prob=0 disables merging, yielding one piece per character."""
  vocab_path = self._CreateVocab()
  no_merge_enc = wpm_encoder.WpmEncoder(vocab_path, merge_prob=0.)
  encoded = no_merge_enc.Encode('Ditto')
  self.assertEqual('D i t t o', encoded)