Пример #1
0
 def __init__(self, params):
     """Constructs the tokenizer and its word-piece (WPM) encoder.

     Args:
       params: Params for this tokenizer. `params.vocab_filepath` is handed to
         wpm_encoder.WpmEncoder; `target_unk_id`, `target_sos_id` and
         `target_eos_id` are read back through `self.params`.

     Raises:
       AssertionError: (only when assertions are enabled) if a configured
         special-token id differs from the one reported by the encoder.
     """
     super(WpmTokenizer, self).__init__(params)
     encoder = wpm_encoder.WpmEncoder(params.vocab_filepath)
     self._wpm_encoder = encoder
     # NOTE(review): self.params is presumably populated by the base-class
     # constructor called above — confirm.
     cfg = self.params
     # Sanity-check that the configured special-token ids agree with the
     # encoder's unknown / sentence-start / sentence-end ids.
     assert cfg.target_unk_id == encoder.unk_id
     assert cfg.target_sos_id == encoder.sentence_start_id
     assert cfg.target_eos_id == encoder.sentence_end_id
Пример #2
0
 def __init__(self, params):
   """Constructs the tokenizer and its word-piece (WPM) encoder.

   Args:
     params: Params for this tokenizer. Through `self.params` it supplies
       `vocab_filepath` and `merge_prob` (passed to wpm_encoder.WpmEncoder)
       and the expected `target_unk_id`, `target_sos_id`, `target_eos_id`.

   Raises:
     AssertionError: (only when assertions are enabled) if a configured
       special-token id differs from the one reported by the encoder.
   """
   super().__init__(params)
   cfg = self.params
   encoder = wpm_encoder.WpmEncoder(cfg.vocab_filepath, cfg.merge_prob)
   self._wpm_encoder = encoder
   # Sanity-check that the configured special-token ids agree with the
   # encoder's unknown / sentence-start / sentence-end ids.
   assert cfg.target_unk_id == encoder.unk_id
   assert cfg.target_sos_id == encoder.sentence_start_id
   assert cfg.target_eos_id == encoder.sentence_end_id
Пример #3
0
def _RunEncoding():
  """Encodes parallel source/target text files into a sharded TFRecord file.

  Source and target files are taken pairwise from the comma-separated
  FLAGS.source_filepaths / FLAGS.target_filepaths and read line by line in
  lockstep. Every line pair increments a running count n. Only pairs with
  ``n % FLAGS.num_shards == FLAGS.shard_id`` are kept. Each kept pair is
  preprocessed once, converted by _MakeTfExample and written, serialized, to
  FLAGS.output_filepath. Pairs for which _MakeTfExample returns a falsy value
  (too long) are skipped.

  Note that ``zip`` truncates silently. Surplus file paths, and surplus lines
  in the longer file of a pair, are ignored.
  """
  enc = wpm_encoder.WpmEncoder(FLAGS.wpm_filepath)
  pairs = zip(
      FLAGS.source_filepaths.split(','), FLAGS.target_filepaths.split(','))
  with tf.python_io.TFRecordWriter(FLAGS.output_filepath) as outf:
    # Running count of line pairs across all files. It drives both the
    # sharding decision and the periodic progress log.
    n = 0
    for source_path, target_path in pairs:
      with tf.gfile.Open(source_path, 'r') as sourcef, \
          tf.gfile.Open(target_path, 'r') as targetf:
        line_pairs = zip(sourcef.readlines(), targetf.readlines())
        for source_line, target_line in line_pairs:
          n += 1
          if n % 10000 == 0:
            tf.logging.info('Watermark[%d]: %d', FLAGS.shard_id, n)
          if n % FLAGS.num_shards != FLAGS.shard_id:
            continue
          # Preprocess each side exactly once and pass the result straight
          # to _MakeTfExample. Re-preprocessing already-preprocessed text is
          # wasteful, and it is wrong if _Prepropcess is not idempotent.
          source_text = _Prepropcess(source_line)
          target_text = _Prepropcess(target_line)
          ex = _MakeTfExample(enc, source_text, target_text)
          if not ex:  # Too long.
            continue
          outf.write(ex.SerializeToString())
Пример #4
0
 def testMergeProb(self):
   """Encodes 'Ditto' with merge_prob=0. and checks the pieces and round trip.

   The expected pieces are the word-boundary marker followed by the single
   characters of 'Ditto'. Decoding the ids must give back b'Ditto'.
   """
   voc = self._CreateVocab()
   enc = wpm_encoder.WpmEncoder(voc, merge_prob=0.)
   with tf.Session():
     ids, strs = enc.Encode('Ditto')
     self.assertEqual(u'▁ D i t t o'.encode('utf-8'),
                      tf.strings.reduce_join(strs, separator=' ').eval())
     # Round-trip through the same encoder that produced the ids, rather
     # than the unrelated fixture encoder self._enc.
     self.assertEqual(b'Ditto', enc.Decode(ids).eval())
Пример #5
0
def _RunEncoding():
    """Encodes parallel source/target text files into a sharded TFRecord file.

    Source and target files are taken pairwise from the comma-separated
    FLAGS.source_filepaths / FLAGS.target_filepaths and read line by line in
    lockstep. Every line pair increments a running count n. Only pairs with
    ``n % FLAGS.num_shards == FLAGS.shard_id`` are kept. Each kept pair goes
    through these steps:

    * it is preprocessed and format-checked;
    * it is word-piece encoded with a single ``sess.run`` call;
    * it is converted by _MakeTfExample and written to FLAGS.output_filepath.

    Pairs for which _MakeTfExample returns a falsy value (too long) are
    skipped.

    Note that ``zip`` truncates silently. Surplus file paths, and surplus
    lines in the longer file of a pair, are ignored.
    """
    enc = wpm_encoder.WpmEncoder(FLAGS.wpm_filepath)
    # Build the encoding graph once; every kept line pair is fed through it.
    src_txt_placeholder = tf.placeholder(tf.string, [])
    src_encode_op = enc.Encode(src_txt_placeholder)
    tgt_txt_placeholder = tf.placeholder(tf.string, [])
    tgt_encode_op = enc.Encode(tgt_txt_placeholder)
    pairs = list(
        zip(FLAGS.source_filepaths.split(','),
            FLAGS.target_filepaths.split(',')))
    # Manage the session with `with` so it is closed on exit, including when
    # encoding or writing raises. Previously it was never closed.
    with tf.Session() as sess, tf.python_io.TFRecordWriter(
            FLAGS.output_filepath) as outf:
        n = 0
        for source_path, target_path in pairs:
            with tf.gfile.Open(source_path, 'r') as sourcef:
                with tf.gfile.Open(target_path, 'r') as targetf:
                    for source_line, target_line in zip(
                            sourcef.readlines(), targetf.readlines()):
                        n += 1
                        if n % 10000 == 0:
                            tf.logging.info('Watermark[%d]: %d',
                                            FLAGS.shard_id, n)
                        if n % FLAGS.num_shards != FLAGS.shard_id:
                            continue
                        source_text = _Preprocess(source_line)
                        target_text = _Preprocess(target_line)
                        # By convention:
                        # * source always ends in </s>, never starts with <s>.
                        # * target never ends in </s>, always starts with <s>.
                        _AssertTextFormat(source_text)
                        _AssertTextFormat(target_text)
                        ((src_i, src_s), (tgt_i, tgt_s)) = sess.run(
                            [src_encode_op, tgt_encode_op],
                            feed_dict={
                                src_txt_placeholder: source_text,
                                tgt_txt_placeholder: target_text
                            },
                        )
                        ex = _MakeTfExample(enc, src_i, src_s, tgt_i, tgt_s)
                        if not ex:  # Too long.
                            continue
                        outf.write(ex.SerializeToString())
Пример #6
0
 def setUp(self):
   """Builds the shared fixture encoder `self._enc` from `self._CreateVocab()`.

   NOTE(review): the base-class setUp is not invoked here. Confirm that is
   intended for this TestCase base.
   """
   self._enc = wpm_encoder.WpmEncoder(self._CreateVocab())
Пример #7
0
 def testMergeProb(self):
   """Checks that with merge_prob=0. 'Ditto' encodes to 'D i t t o'."""
   encoder = wpm_encoder.WpmEncoder(self._CreateVocab(), merge_prob=0.)
   pieces = encoder.Encode('Ditto')
   self.assertEqual('D i t t o', pieces)