示例#1
0
 def setUp(self):
   """Runs the parent setUp, then records the label and calib test files.

   Both locations are obtained from test_helper.test_src_dir_path and stored
   on the instance as `_label_file` and `_calib_file`.
   """
   super(KittiDataTest, self).setUp()
   resolve = test_helper.test_src_dir_path
   label_path = resolve('tasks/car/testdata/kitti_raw_label_testdata.txt')
   calib_path = resolve('tasks/car/testdata/kitti_raw_calib_testdata.txt')
   self._label_file = label_path
   self._calib_file = calib_path
示例#2
0
 def _HistFile(self):
      """Returns the history.txt test data location from test_helper."""
      hist_path = test_helper.test_src_dir_path(
          'core/ops/testdata/history.txt')
      return hist_path
示例#3
0
 def _Params(self):
      """Builds SentencePieceTokenizer params for the en-1k test model.

      Returns:
        The params object with `spm_model` set to the en-1k.spm.model test
        file and `vocab_size` set to 1024.
      """
      tokenizer_params = tokenizers.SentencePieceTokenizer.Params()
      model_path = test_helper.test_src_dir_path(
          'core/testdata/en-1k.spm.model')
      tokenizer_params.spm_model = model_path
      tokenizer_params.vocab_size = 1024
      return tokenizer_params
示例#4
0
 def _TfEventFile(self):
   """Returns the events.out.tfevents.test data location from test_helper."""
   event_path = test_helper.test_src_dir_path(
       'core/ops/testdata/events.out.tfevents.test')
   return event_path
示例#5
0
 def _BleuFile(self):
   """Returns the history_bleu.txt test data location from test_helper."""
   bleu_path = test_helper.test_src_dir_path(
       'core/ops/testdata/history_bleu.txt')
   return bleu_path
示例#6
0
    def testTextPackedInputTextPacking(self):
        p = input_generator.TextPackedInput.Params()
        p.flush_every_n = 0
        p.require_sequential_order = True
        p.file_pattern = 'text:' + test_helper.test_src_dir_path(
            'tasks/mt/testdata/en_de.text')
        p.tokenizer = tokenizers.WpmTokenizer.Params().Set(
            vocab_filepath=test_helper.test_src_dir_path(
                'tasks/mt/wpm-ende-2k.voc'),
            vocab_size=2000)
        # We repeat the 2-line file twice for a batch of 2, each packing both lines.
        p.repeat_count = 2
        p.source_max_length = 16
        p.target_max_length = 20
        p.bucket_batch_limit = [2]
        p.packing_factor = 2
        with self.session() as sess:
            inp = p.Instantiate()
            batch_tensor = inp.GetPreprocessedInputBatch()
            batch, num_examples = sess.run(
                [batch_tensor, inp.GlobalBatchSize()])
            self.assertEqual(num_examples, 4)
        self.assertAllEqual(
            batch.src.ids,
            np.array([
                [
                    109, 251, 98, 595, 1009, 245, 326, 129, 4, 2, 115, 276, 18,
                    66, 2, 0
                ],
                [
                    115, 276, 18, 66, 2, 109, 251, 98, 595, 1009, 245, 326,
                    129, 4, 2, 0
                ],
            ]))
        self.assertAllEqual(
            batch.src.segment_ids,
            np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0],
                      [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0]],
                     dtype=np.float32))
        self.assertAllEqual(
            batch.src.segment_pos,
            np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 0],
                      [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0]]))
        self.assertAllEqual(
            batch.src.strs,
            np.array([
                b'Too much has changed.\tHello!',
                b'Hello!\tToo much has changed.'
            ]))

        self.assertAllEqual(
            batch.tgt.ids,
            np.array([
                [
                    1, 197, 446, 458, 419, 284, 323, 1411, 571, 456, 409, 13,
                    4, 1, 115, 281, 18, 66, 0, 0
                ],
                [
                    1, 115, 281, 18, 66, 1, 197, 446, 458, 419, 284, 323, 1411,
                    571, 456, 409, 13, 4, 0, 0
                ],
            ]))
        self.assertAllEqual(
            batch.tgt.labels,
            np.array([
                [
                    197, 446, 458, 419, 284, 323, 1411, 571, 456, 409, 13, 4,
                    2, 115, 281, 18, 66, 2, 0, 0
                ],
                [
                    115, 281, 18, 66, 2, 197, 446, 458, 419, 284, 323, 1411,
                    571, 456, 409, 13, 4, 2, 0, 0
                ],
            ]))
        self.assertAllEqual(
            batch.tgt.segment_ids,
            np.array(
                [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0],
                 [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0]],
                dtype=np.float32))
        self.assertAllEqual(
            batch.tgt.segment_pos,
            np.array([[
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0, 1, 2, 3, 4, 0, 0
            ], [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0,
                0]]))
        self.assertAllEqual(
            batch.tgt.strs,
            np.array([
                b'Daf\xc3\xbcr hat sich zu viel ver\xc3\xa4ndert.\tHallo!',
                b'Hallo!\tDaf\xc3\xbcr hat sich zu viel ver\xc3\xa4ndert.'
            ]))
示例#7
0
 def testTextPackedInputProto(self):
      """Checks TextPackedInput reading 'sentence_proto' records from en_fr.

      The input is configured with an AsciiTokenizer and
      `pad_to_max_seq_length` enabled. Every flattened item of the batch is
      required to have a fully defined shape, then one batch is fetched and
      its strings, ids and labels are compared with hard-coded expected values.
      """
      params = input_generator.TextPackedInput.Params()
      params.flush_every_n = 0
      params.require_sequential_order = True
      params.repeat_count = 1
      params.file_pattern = 'tfrecord:' + test_helper.test_src_dir_path(
          'tasks/mt/testdata/en_fr.tfrecord')
      params.pad_to_max_seq_length = True
      params.tokenizer = tokenizers.AsciiTokenizer.Params()
      params.input_file_type = 'sentence_proto'
      params.source_max_length = 22
      params.target_max_length = 24
      params.bucket_batch_limit = [2]

      with self.session() as sess:
          generator = params.Instantiate()
          batch_fetch = generator.GetPreprocessedInputBatch()
          # Every item must have a fully defined shape; the item key is
          # passed as the assertion message.
          for key, tensor in batch_fetch.FlattenItems():
              self.assertTrue(tensor.shape.is_fully_defined(), key)
          fetches = [batch_fetch, generator.GlobalBatchSize()]
          batch, num_examples = sess.run(fetches)

      self.assertEqual(num_examples, 2)
      self.assertEqual(len(batch.src), 7)

      expected_src_strs = [b'I love paragliding!', b'vol biv paragliding']
      self.assertAllEqual(batch.src.strs, expected_src_strs)

      expected_tgt_strs = [b"J'adore le parapente!", b'vol biv parapente']
      self.assertAllEqual(batch.tgt.strs, expected_tgt_strs)

      expected_src_ids = np.array([
          [13, 3, 16, 19, 26, 9, 3, 20, 5, 22, 5,
           11, 16, 13, 8, 13, 18, 11, 35, 2, 0, 0],
          [26, 19, 16, 3, 6, 13, 26, 3, 20, 5, 22,
           5, 11, 16, 13, 8, 13, 18, 11, 2, 0, 0],
      ])
      self.assertAllEqual(batch.src.ids, expected_src_ids)

      expected_tgt_ids = np.array([
          [1, 14, 32, 5, 8, 19, 22, 9, 3, 16, 9, 3,
           20, 5, 22, 5, 20, 9, 18, 24, 9, 35, 0, 0],
          [1, 26, 19, 16, 3, 6, 13, 26, 3, 20, 5, 22,
           5, 20, 9, 18, 24, 9, 0, 0, 0, 0, 0, 0],
      ])
      self.assertAllEqual(batch.tgt.ids, expected_tgt_ids)

      expected_tgt_labels = np.array([
          [14, 32, 5, 8, 19, 22, 9, 3, 16, 9, 3, 20,
           5, 22, 5, 20, 9, 18, 24, 9, 35, 2, 0, 0],
          [26, 19, 16, 3, 6, 13, 26, 3, 20, 5, 22, 5,
           20, 9, 18, 24, 9, 2, 0, 0, 0, 0, 0, 0],
      ])
      self.assertAllEqual(batch.tgt.labels, expected_tgt_labels)