Example #1
0
 def test_not_packed_not_batched(self):
   """Checks get_ds_metrics on an unpacked dataset with batch_size=1.

   Six three-token examples go through pack_and_batch_ds with packing
   disabled; get_ds_metrics is expected to return (6, 6, 1, 3).
   """
   # Each target sequence is the matching input sequence scaled by 11.
   input_seqs = (
       [1, 2, 3],
       [4, 5, 6],
       [7, 8, 9],
       [10, 20, 30],
       [40, 50, 60],
       [70, 80, 90],
   )
   raw_examples = [{
       'inputs': seq,
       'targets': [11 * tok for tok in seq],
   } for seq in input_seqs]
   unpacked_ds = data.pack_and_batch_ds(
       _get_ds_from_examples(raw_examples),
       batch_size=1,
       max_length=7,
       extend_to_fill=False,
       drop_remainder=False,
       pack=False,
   )
   self.assertEqual(data.get_ds_metrics(unpacked_ds), (6, 6, 1, 3))
Example #2
0
 def test_packed_not_batched_errors(self):
   """Checks get_ds_metrics refuses a batched but unpacked dataset.

   Despite the test name, the dataset built here is batched (batch_size=2)
   but not packed (pack=False); get_ds_metrics is expected to raise
   NotImplementedError for it.
   """
   input_seqs = [
       [1, 2, 3],
       [4, 5, 6],
       [7, 8, 9],
       [10, 20, 30],
       [40, 50, 60],
       [70, 80, 90],
   ]
   raw_examples = []
   for seq in input_seqs:
     # Targets mirror the inputs, scaled by 11.
     raw_examples.append({'inputs': seq, 'targets': [tok * 11 for tok in seq]})

   batched_ds = data.pack_and_batch_ds(
       _get_ds_from_examples(raw_examples),
       batch_size=2,
       max_length=7,
       extend_to_fill=False,
       drop_remainder=False,
       pack=False)
   with self.assertRaises(NotImplementedError):
     data.get_ds_metrics(batched_ds)
Example #3
0
 def test_dropped_full_final_batch(self):
   """Checks get_ds_final_batch_size when the trailing batch is dropped.

   Six three-token examples are packed (max_length=7) and batched in pairs
   with drop_remainder=True; get_ds_final_batch_size is expected to return
   (1, 2).
   """
   # NOTE(review): the meaning of each tuple element is defined by
   # data.get_ds_final_batch_size, which is not visible here.
   inputs = [[1, 2, 3], [4, 5, 6], [7, 8, 9],
             [10, 20, 30], [40, 50, 60], [70, 80, 90]]
   targets = [[11, 22, 33], [44, 55, 66], [77, 88, 99],
              [110, 220, 330], [440, 550, 660], [770, 880, 990]]
   raw_examples = [
       {'inputs': inp, 'targets': tgt} for inp, tgt in zip(inputs, targets)
   ]
   packed_ds = data.pack_and_batch_ds(
       _get_ds_from_examples(raw_examples),
       batch_size=2,
       max_length=7,
       extend_to_fill=False,
       drop_remainder=True,
       pack=True)
   final_batch_size = data.get_ds_final_batch_size(packed_ds)
   self.assertEqual(final_batch_size, (1, 2))
Example #4
0
 def test_packed_and_batched(self):
   """Checks the first packed batch holds four distinct source examples.

   With max_length=7, two three-token examples share each packed row, so the
   first batch of two rows is expected to contain four unique examples.
   """
   inputs = [[1, 2, 3], [4, 5, 6], [7, 8, 9],
             [10, 20, 30], [40, 50, 60], [70, 80, 90]]
   targets = [[11, 22, 33], [44, 55, 66], [77, 88, 99],
              [110, 220, 330], [440, 550, 660], [770, 880, 990]]
   raw_examples = [dict(inputs=inp, targets=tgt)
                   for inp, tgt in zip(inputs, targets)]

   packed_ds = data.pack_and_batch_ds(
       _get_ds_from_examples(raw_examples),
       batch_size=2,
       max_length=7,
       extend_to_fill=False,
       drop_remainder=False,
       pack=True)
   first_batch = next(iter(packed_ds))
   self.assertEqual(data.get_unique_examples(first_batch), 4)
Example #5
0
 def test_drop_remainder_drops_final_batch(self):
   """Checks drop_remainder=True discards a trailing partial batch.

   Packing puts two three-token examples into each seven-token row, so six
   examples yield three packed rows. That is one full batch of two plus a
   partial batch of one. With drop_remainder=True only the full batch
   should be produced.
   """
   # Six examples are used, not four. Four examples pack into exactly two
   # rows, which is one full batch with no remainder, so the test would pass
   # even if drop_remainder were ignored.
   raw_examples = [{
       'inputs': [1, 2, 3],
       'targets': [11, 22, 33],
   }, {
       'inputs': [4, 5, 6],
       'targets': [44, 55, 66],
   }, {
       'inputs': [7, 8, 9],
       'targets': [77, 88, 99],
   }, {
       'inputs': [10, 20, 30],
       'targets': [110, 220, 330],
   }, {
       'inputs': [40, 50, 60],
       'targets': [440, 550, 660],
   }, {
       'inputs': [70, 80, 90],
       'targets': [770, 880, 990],
   }]

   def _build(drop_remainder):
     """Packs and batches raw_examples with the given drop_remainder flag."""
     return data.pack_and_batch_ds(
         _get_ds_from_examples(raw_examples),
         batch_size=2,
         max_length=7,
         extend_to_fill=False,
         drop_remainder=drop_remainder,
         pack=True)

   # Control: without drop_remainder the partial third row forms a second
   # batch. This shows there really is a remainder to drop below.
   self.assertEqual(len(list(_build(drop_remainder=False))), 2)

   pnb_ds_iter = iter(_build(drop_remainder=True))
   shape_of_first_batch = next(pnb_ds_iter)['inputs'].shape
   self.assertEqual(shape_of_first_batch, (2, 7))
   # The partial second batch must have been dropped.
   with self.assertRaises(StopIteration):
     next(pnb_ds_iter)
Example #6
0
 def test_pack_repeats_only_inputs_targets(self):
   """Checks which features packing adds to the element spec.

   The examples carry an extra 'edits' feature next to 'inputs' and
   'targets'. After packing, the spec is expected to keep all three and to
   gain '_position' and '_segmentation' features for 'inputs' and 'targets'
   only, with none for 'edits'.
   """
   raw_examples = [{
       'inputs': seq,
       'targets': [11 * tok for tok in seq],
       'edits': [0, 1, 2],
   } for seq in ([1, 2, 3], [4, 5, 6], [7, 8, 9])]
   packed_ds = data.pack_and_batch_ds(
       _get_ds_from_examples(raw_examples),
       batch_size=2,
       max_length=7,
       extend_to_fill=False,
       drop_remainder=False,
       pack=True)

   expected_keys = ['edits', 'inputs', 'targets']
   for feature in ('inputs', 'targets'):
     expected_keys += [feature + '_position', feature + '_segmentation']
   self.assertCountEqual(list(packed_ds.element_spec.keys()), expected_keys)
Example #7
0
 def test_empty_batch(self):
   """Checks a dataset whose only batch is partial ends up empty.

   Two three-token examples pack into a single row. With batch_size=2 and
   drop_remainder=True, get_ds_final_batch_size is expected to report
   (0, 0).
   """
   raw_examples = [
       {'inputs': seq, 'targets': [11 * tok for tok in seq]}
       for seq in ([1, 2, 3], [4, 5, 6])
   ]
   packed_ds = data.pack_and_batch_ds(
       _get_ds_from_examples(raw_examples),
       batch_size=2,
       max_length=7,
       extend_to_fill=False,
       drop_remainder=True,
       pack=True)
   self.assertEqual(data.get_ds_final_batch_size(packed_ds), (0, 0))
Example #8
0
 def test_pack_pads_to_max_len(self):
   """Checks packed rows are padded out to max_length.

   Three three-token examples are packed with max_length=8. No row uses all
   eight positions, yet the single batch is expected to have inputs shaped
   (2, 8).
   """
   inputs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
   targets = [[11, 22, 33], [44, 55, 66], [77, 88, 99]]
   raw_examples = [{'inputs': inp, 'targets': tgt}
                   for inp, tgt in zip(inputs, targets)]

   packed_ds = data.pack_and_batch_ds(
       _get_ds_from_examples(raw_examples),
       batch_size=2,
       max_length=8,
       extend_to_fill=False,
       drop_remainder=False,
       pack=True)
   only_batch = next(iter(packed_ds))
   self.assertEqual(only_batch['inputs'].shape, (2, 8))
Example #9
0
 def test_extend_to_fill_extends_final_batch(self):
   """Checks extend_to_fill tops up the final partial batch.

   Six three-token examples pack into three rows of max_length=7. With
   extend_to_fill=True, the second batch is completed with a row that starts
   over from the beginning of the data, so both batches have shape (2, 7).
   """
   # Each target sequence is the matching input sequence scaled by 11.
   input_seqs = (
       [1, 2, 3],
       [4, 5, 6],
       [7, 8, 9],
       [10, 20, 30],
       [40, 50, 60],
       [70, 80, 90],
   )
   raw_examples = [{
       'inputs': seq,
       'targets': [11 * tok for tok in seq],
   } for seq in input_seqs]

   batches = iter(
       data.pack_and_batch_ds(
           _get_ds_from_examples(raw_examples),
           batch_size=2,
           max_length=7,
           extend_to_fill=True,
           drop_remainder=False,
           pack=True))
   first_inputs = next(batches)['inputs']
   final_inputs = next(batches)['inputs']

   expected_first = [[1, 2, 3, 4, 5, 6, 0], [7, 8, 9, 10, 20, 30, 0]]
   # The second row wraps around to the first two examples again.
   expected_final = [[40, 50, 60, 70, 80, 90, 0], [1, 2, 3, 4, 5, 6, 0]]

   self.assertEqual(first_inputs.shape, (2, 7))
   self.assertListEqual(first_inputs.numpy().tolist(), expected_first)
   self.assertEqual(final_inputs.shape, (2, 7))
   self.assertListEqual(final_inputs.numpy().tolist(), expected_final)