def setUp(self):
    """Resolve the KITTI raw label and calibration test data paths.

    The resolved paths are stored on the instance as `_label_file` and
    `_calib_file`.
    """
    super(KittiDataTest, self).setUp()
    resolve = test_helper.test_src_dir_path
    self._label_file = resolve(
        'tasks/car/testdata/kitti_raw_label_testdata.txt')
    self._calib_file = resolve(
        'tasks/car/testdata/kitti_raw_calib_testdata.txt')
def _HistFile(self):
    """Return the resolved path of the `history.txt` test data file."""
    hist_path = test_helper.test_src_dir_path(
        'core/ops/testdata/history.txt')
    return hist_path
def _Params(self):
    """Build SentencePieceTokenizer params for the en-1k test model.

    Returns:
      A `SentencePieceTokenizer.Params()` object whose `spm_model` is the
      resolved path of `en-1k.spm.model` and whose `vocab_size` is 1024.
    """
    tokenizer_params = tokenizers.SentencePieceTokenizer.Params()
    model_path = test_helper.test_src_dir_path(
        'core/testdata/en-1k.spm.model')
    tokenizer_params.spm_model = model_path
    tokenizer_params.vocab_size = 1024
    return tokenizer_params
def _TfEventFile(self):
    """Return the resolved path of the `events.out.tfevents.test` file."""
    event_path = test_helper.test_src_dir_path(
        'core/ops/testdata/events.out.tfevents.test')
    return event_path
def _BleuFile(self):
    """Return the resolved path of the `history_bleu.txt` test data file."""
    bleu_path = test_helper.test_src_dir_path(
        'core/ops/testdata/history_bleu.txt')
    return bleu_path
def testTextPackedInputTextPacking(self):
    """Check TextPackedInput output when examples are packed two per row.

    Reads the `en_de.text` test file through a WPM tokenizer with
    `packing_factor = 2`, runs one batch, and compares the source/target
    ids, labels, segment ids, segment positions and tab-joined strings
    against hard-coded expectations.
    """
    resolve = test_helper.test_src_dir_path

    params = input_generator.TextPackedInput.Params()
    params.flush_every_n = 0
    params.require_sequential_order = True
    params.file_pattern = 'text:' + resolve('tasks/mt/testdata/en_de.text')
    vocab_path = resolve('tasks/mt/wpm-ende-2k.voc')
    params.tokenizer = tokenizers.WpmTokenizer.Params().Set(
        vocab_filepath=vocab_path, vocab_size=2000)
    # The 2-line file is repeated twice, giving a batch of 2 rows where each
    # row packs both lines.
    params.repeat_count = 2
    params.source_max_length = 16
    params.target_max_length = 20
    params.bucket_batch_limit = [2]
    params.packing_factor = 2

    # Expected source-side values (2 rows, padded to source_max_length=16).
    want_src_ids = np.array([
        [109, 251, 98, 595, 1009, 245, 326, 129, 4, 2,
         115, 276, 18, 66, 2, 0],
        [115, 276, 18, 66, 2,
         109, 251, 98, 595, 1009, 245, 326, 129, 4, 2, 0],
    ])
    want_src_segment_ids = np.array(
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0],
         [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0]],
        dtype=np.float32)
    want_src_segment_pos = np.array(
        [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 0],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0]])
    want_src_strs = np.array([
        b'Too much has changed.\tHello!',
        b'Hello!\tToo much has changed.',
    ])

    # Expected target-side values (2 rows, padded to target_max_length=20).
    want_tgt_ids = np.array([
        [1, 197, 446, 458, 419, 284, 323, 1411, 571, 456, 409, 13, 4,
         1, 115, 281, 18, 66, 0, 0],
        [1, 115, 281, 18, 66,
         1, 197, 446, 458, 419, 284, 323, 1411, 571, 456, 409, 13, 4, 0, 0],
    ])
    want_tgt_labels = np.array([
        [197, 446, 458, 419, 284, 323, 1411, 571, 456, 409, 13, 4, 2,
         115, 281, 18, 66, 2, 0, 0],
        [115, 281, 18, 66, 2,
         197, 446, 458, 419, 284, 323, 1411, 571, 456, 409, 13, 4, 2, 0, 0],
    ])
    want_tgt_segment_ids = np.array(
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0],
         [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0]],
        dtype=np.float32)
    want_tgt_segment_pos = np.array(
        [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0, 1, 2, 3, 4, 0, 0],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0, 0]])
    want_tgt_strs = np.array([
        b'Daf\xc3\xbcr hat sich zu viel ver\xc3\xa4ndert.\tHallo!',
        b'Hallo!\tDaf\xc3\xbcr hat sich zu viel ver\xc3\xa4ndert.',
    ])

    with self.session() as sess:
        generator = params.Instantiate()
        batch_tensor = generator.GetPreprocessedInputBatch()
        fetches = [batch_tensor, generator.GlobalBatchSize()]
        batch, num_examples = sess.run(fetches)

        self.assertEqual(num_examples, 4)

        src = batch.src
        self.assertAllEqual(src.ids, want_src_ids)
        self.assertAllEqual(src.segment_ids, want_src_segment_ids)
        self.assertAllEqual(src.segment_pos, want_src_segment_pos)
        self.assertAllEqual(src.strs, want_src_strs)

        tgt = batch.tgt
        self.assertAllEqual(tgt.ids, want_tgt_ids)
        self.assertAllEqual(tgt.labels, want_tgt_labels)
        self.assertAllEqual(tgt.segment_ids, want_tgt_segment_ids)
        self.assertAllEqual(tgt.segment_pos, want_tgt_segment_pos)
        self.assertAllEqual(tgt.strs, want_tgt_strs)
def testTextPackedInputProto(self):
    """Check TextPackedInput output when reading 'sentence_proto' records.

    Reads the `en_fr.tfrecord` test file through an ASCII tokenizer with
    `pad_to_max_seq_length` enabled, verifies that every tensor in the
    batch has a fully defined static shape, then runs one batch and
    compares the strings, ids and labels against hard-coded expectations.
    """
    params = input_generator.TextPackedInput.Params()
    params.flush_every_n = 0
    params.require_sequential_order = True
    params.repeat_count = 1
    record_path = test_helper.test_src_dir_path(
        'tasks/mt/testdata/en_fr.tfrecord')
    params.file_pattern = 'tfrecord:' + record_path
    params.pad_to_max_seq_length = True
    params.tokenizer = tokenizers.AsciiTokenizer.Params()
    params.input_file_type = 'sentence_proto'
    params.source_max_length = 22
    params.target_max_length = 24
    params.bucket_batch_limit = [2]

    # Expected values (2 rows; source padded to 22, target padded to 24).
    want_src_strs = [b'I love paragliding!', b'vol biv paragliding']
    want_tgt_strs = [b"J'adore le parapente!", b'vol biv parapente']
    want_src_ids = np.array([
        [13, 3, 16, 19, 26, 9, 3, 20, 5, 22, 5, 11, 16, 13, 8, 13, 18, 11,
         35, 2, 0, 0],
        [26, 19, 16, 3, 6, 13, 26, 3, 20, 5, 22, 5, 11, 16, 13, 8, 13, 18,
         11, 2, 0, 0],
    ])
    want_tgt_ids = np.array([
        [1, 14, 32, 5, 8, 19, 22, 9, 3, 16, 9, 3, 20, 5, 22, 5, 20, 9, 18,
         24, 9, 35, 0, 0],
        [1, 26, 19, 16, 3, 6, 13, 26, 3, 20, 5, 22, 5, 20, 9, 18, 24, 9,
         0, 0, 0, 0, 0, 0],
    ])
    want_tgt_labels = np.array([
        [14, 32, 5, 8, 19, 22, 9, 3, 16, 9, 3, 20, 5, 22, 5, 20, 9, 18, 24,
         9, 35, 2, 0, 0],
        [26, 19, 16, 3, 6, 13, 26, 3, 20, 5, 22, 5, 20, 9, 18, 24, 9, 2,
         0, 0, 0, 0, 0, 0],
    ])

    with self.session() as sess:
        generator = params.Instantiate()
        batch_tensor = generator.GetPreprocessedInputBatch()

        # Static shapes are checked on the graph tensors before running.
        for key, tensor in batch_tensor.FlattenItems():
            self.assertTrue(tensor.shape.is_fully_defined(), key)

        fetches = [batch_tensor, generator.GlobalBatchSize()]
        batch, num_examples = sess.run(fetches)

        self.assertEqual(num_examples, 2)
        self.assertEqual(len(batch.src), 7)
        self.assertAllEqual(batch.src.strs, want_src_strs)
        self.assertAllEqual(batch.tgt.strs, want_tgt_strs)
        self.assertAllEqual(batch.src.ids, want_src_ids)
        self.assertAllEqual(batch.tgt.ids, want_tgt_ids)
        self.assertAllEqual(batch.tgt.labels, want_tgt_labels)