    def test_iterator_single_dataset(self):
        r"""Tests iterating over a single dataset."""
        data = MonoTextData(self._test_hparams)
        data_iterator = DataIterator(data)
        data_iterator.switch_to_dataset(dataset_name="data")
        iterator = data_iterator.get_iterator()
        i = 1001
        for batch in iterator:
            self.assertEqual(batch.batch_size,
                             self._test_hparams['batch_size'])
            np.testing.assert_array_equal(batch['length'], [1, 1])
            for example in batch['text']:
                self.assertEqual(example[0], str(i))
                i += 1
        # The test split holds the numbers 1001..2000 (see the sketch below).
        self.assertEqual(i, 2001)
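A minimal sketch of the fixture this test assumes (file contents, hparams values, and the helper name are assumptions, not taken from the source): the test split holds the numbers 1001..2000, one per line, with BOS/EOS disabled so every example has length 1. `MonoTextData` and `DataIterator` are from `texar.torch.data`.

import tempfile


def _make_test_hparams():
    # Hypothetical fixture: write 1001..2000, one number per line, and reuse
    # the same file as the vocabulary.
    data_file = tempfile.NamedTemporaryFile(mode='w', delete=False,
                                            suffix='.txt')
    data_file.write('\n'.join(str(n) for n in range(1001, 2001)))
    data_file.close()
    return {
        'batch_size': 2,  # matches the [1, 1] length assertion above
        'shuffle': False,
        'dataset': {
            'files': data_file.name,
            'vocab_file': data_file.name,
            'bos_token': '',  # disable BOS/EOS so each example has length 1
            'eos_token': '',
        },
    }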
Example #2
    def _test_modes_with_workers(self, lazy_mode: str, cache_mode: str,
                                 num_workers: int):
        from tqdm import tqdm
        gc.collect()
        with work_in_progress(f"Data loading with lazy mode '{lazy_mode}' "
                              f"and cache mode '{cache_mode}' "
                              f"with {num_workers} workers"):
            # Re-sample memory at each checkpoint; reading it once up front
            # would print the same stale value three times.
            print(f"Memory before: {get_process_memory():.2f} MB")
            with work_in_progress("Construction"):
                data = ParallelData(
                    self.source, '../../Downloads/src.vocab',
                    '../../Downloads/tgt.vocab', {
                        'batch_size': self.batch_size,
                        'lazy_strategy': lazy_mode,
                        'cache_strategy': cache_mode,
                        'num_parallel_calls': num_workers,
                        'shuffle': False,
                        'allow_smaller_final_batch': False,
                        'max_dataset_size': 100000,
                    })
            print(f"Memory after construction: {get_process_memory():.2f} MB")
            iterator = DataIterator(data)
            with work_in_progress("Iteration"):
                for batch in tqdm(iterator, leave=False):
                    self.assertEqual(batch.batch_size, self.batch_size)
            gc.collect()
            print(f"Memory after iteration: {get_process_memory():.2f} MB")
            # Iterate again to confirm the iterator is reusable and that
            # cached data is picked up on the second pass.
            with work_in_progress("2nd iteration"):
                for batch in tqdm(iterator, leave=False):
                    self.assertEqual(batch.batch_size, self.batch_size)
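The `get_process_memory` and `work_in_progress` helpers come from the enclosing test module; their implementations are not shown in the source, but a plausible sketch, assuming a psutil-based RSS reading:

import contextlib
import os
import time

import psutil


def get_process_memory() -> float:
    # Resident set size of the current process, in megabytes.
    return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)


@contextlib.contextmanager
def work_in_progress(msg: str):
    # Print a progress message, run the block, then report wall time.
    print(msg + "... ", end='', flush=True)
    start = time.time()
    yield
    print(f"done. ({time.time() - start:.2f}s)")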
Example #3
    def _test_modes_with_workers(self,
                                 lazy_mode: str,
                                 cache_mode: str,
                                 num_workers: int,
                                 parallelize_processing: bool = True,
                                 support_random_access: bool = False,
                                 shuffle: bool = False,
                                 **kwargs):
        hparams = {
            'batch_size': self.batch_size,
            'lazy_strategy': lazy_mode,
            'cache_strategy': cache_mode,
            'num_parallel_calls': num_workers,
            'shuffle': shuffle,
            'shuffle_buffer_size': self.size // 5,
            'parallelize_processing': parallelize_processing,
            'allow_smaller_final_batch': False,
            **kwargs,
        }
        numbers_data = [[x] * self.seq_len for x in range(self.size)]
        string_data = [
            ' '.join(map(str, range(self.seq_len))) for _ in range(self.size)
        ]
        if not support_random_access:
            source = ZipDataSource(  # type: ignore
                IterDataSource(numbers_data), SequenceDataSource(string_data))
        else:
            source = ZipDataSource(SequenceDataSource(numbers_data),
                                   SequenceDataSource(string_data))
        data = MockDataBase(source, hparams)  # type: ignore
        iterator = DataIterator(data)

        if data._hparams.allow_smaller_final_batch:
            total_examples = self.size
            total_batches = ((self.size + self.batch_size - 1)
                             // self.batch_size)
        else:
            total_examples = self.size // self.batch_size * self.batch_size
            total_batches = self.size // self.batch_size

        def check_batch(idx, batch):
            if idx == total_batches - 1:
                batch_size = (total_examples - 1) % self.batch_size + 1
            else:
                batch_size = self.batch_size
            self.assertEqual(batch.numbers.shape, (batch_size, self.seq_len))
            if not shuffle:
                numbers = np.asarray(
                    [idx * self.batch_size + x + 1 for x in range(batch_size)])
                self.assertTrue(
                    np.all(batch.numbers == numbers[:, np.newaxis]))

        # Check laziness: with lazy mode 'none' everything is processed at
        # construction; with 'process' the raw examples are loaded but not
        # yet processed; with 'all' nothing is done until iteration.
        if parallelize_processing:
            if lazy_mode == 'none':
                self.assertEqual(len(data._processed_cache), self.size)
            else:
                self.assertEqual(len(data._processed_cache), 0)
                if not support_random_access:
                    if lazy_mode == 'process':
                        self.assertEqual(len(data._cached_source._cache),
                                         self.size)
                    else:
                        self.assertEqual(len(data._cached_source._cache), 0)

        # first epoch
        cnt = 0
        for idx, batch in enumerate(iterator):
            check_batch(idx, batch)
            cnt += 1
        self.assertEqual(cnt, total_batches)

        # Check the cache after the first epoch: 'processed' keeps processed
        # examples, 'loaded' keeps only raw loaded examples, and 'none'
        # keeps nothing.
        if parallelize_processing:
            if cache_mode == 'none':
                self.assertEqual(len(data._processed_cache), 0)
            elif cache_mode == 'loaded':
                self.assertEqual(len(data._processed_cache), 0)
            else:
                self.assertEqual(len(data._processed_cache), self.size)
            if lazy_mode != 'none' and not support_random_access:
                if cache_mode == 'none':
                    self.assertEqual(len(data._cached_source._cache), 0)
                elif cache_mode == 'loaded':
                    self.assertEqual(len(data._cached_source._cache),
                                     self.size)
                else:
                    self.assertEqual(len(data._cached_source._cache), 0)

        # second epoch
        cnt = 0
        for idx, batch in enumerate(iterator):
            check_batch(idx, batch)
            cnt += 1
        self.assertEqual(cnt, total_batches)

        # The cache state should be unchanged after the second epoch.
        if parallelize_processing:
            if cache_mode == 'none':
                self.assertEqual(len(data._processed_cache), 0)
            elif cache_mode == 'loaded':
                self.assertEqual(len(data._processed_cache), 0)
            else:
                self.assertEqual(len(data._processed_cache), self.size)
            if lazy_mode != 'none' and not support_random_access:
                if cache_mode == 'none':
                    self.assertEqual(len(data._cached_source._cache), 0)
                elif cache_mode == 'loaded':
                    self.assertEqual(len(data._cached_source._cache),
                                     self.size)
                else:
                    self.assertEqual(len(data._cached_source._cache), 0)
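A sketch of the `MockDataBase` the test relies on (an assumption; the real class is defined elsewhere in the test module). `check_batch` expects example `i` to carry the value `i + 1`, so `process` must shift the raw numbers by one. `DatasetBase` and `Batch` are from `texar.torch.data`:

import numpy as np
from texar.torch.data import Batch, DatasetBase


class MockDataBase(DatasetBase):
    def process(self, raw_example):
        # Each raw example is a (numbers, string) pair from the ZipDataSource.
        numbers, string = raw_example
        return {'numbers': [x + 1 for x in numbers], 'string': string}

    def collate(self, examples):
        # Stack per-example fields into a Batch with named attributes, so the
        # test can access `batch.numbers`.
        numbers = np.asarray([ex['numbers'] for ex in examples])
        strings = [ex['string'] for ex in examples]
        return Batch(len(examples), numbers=numbers, strings=strings)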
Example #4
    def test_iterator_multi_datasets(self):
        r"""Tests iterating over multiple datasets.
        """
        train = MonoTextData(self._train_hparams)
        test = MonoTextData(self._test_hparams)
        train_batch_size = self._train_hparams["batch_size"]
        test_batch_size = self._test_hparams["batch_size"]
        data_iterator = DataIterator({"train": train, "test": test})
        data_iterator.switch_to_dataset(dataset_name="train")
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            # numbers: 1 - 2000, first 4 vocab entries are special tokens
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        data_iterator.switch_to_dataset(dataset_name="test")
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), test_batch_size)
            number = idx * test_batch_size + 1001
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test `get_iterator` interface
        for idx, val in enumerate(data_iterator.get_iterator('train')):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test exception for invalid dataset name
        with self.assertRaises(ValueError) as context:
            data_iterator.switch_to_dataset('val')
        self.assertIn('not found', str(context.exception))
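A sketch of the assumed fixtures (file names and contents are assumptions): the train file holds 1..1000 and the test file 1001..2000, one number per line, over a shared vocabulary of 1..2000. Since the first 4 vocabulary ids are special tokens, the number `n` (the n-th vocabulary entry) maps to id `n + 3`, which is exactly what the `text_ids` assertions check.

import tempfile


def _write_numbers(lo: int, hi: int) -> str:
    # Write the integers lo..hi-1, one per line, and return the file path.
    f = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt')
    f.write('\n'.join(str(n) for n in range(lo, hi)))
    f.close()
    return f.name


train_file = _write_numbers(1, 1001)    # numbers 1..1000
test_file = _write_numbers(1001, 2001)  # numbers 1001..2000
vocab_file = _write_numbers(1, 2001)    # shared vocabulary 1..2000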
Example #5
    def _run_and_test(self, hparams, discard_index=None):
        # Construct database
        text_data = tx.data.MultiAlignedData(hparams)
        self.assertEqual(
            text_data.vocab(0).size,
            self._vocab_size + len(text_data.vocab(0).special_tokens))

        iterator = DataIterator(text_data)
        for batch in iterator:
            self.assertEqual(set(batch.keys()), set(text_data.list_items()))
            text_0 = batch['0_text']
            text_1 = batch['1_text']
            text_2 = batch['2_text']
            int_3 = batch['label']
            number_1 = batch['4_number1']
            number_2 = batch['4_number2']
            text_4 = batch['4_text']

            for t0, t1, t2, i3, n1, n2, t4 in zip(text_0, text_1, text_2,
                                                  int_3, number_1, number_2,
                                                  text_4):
                np.testing.assert_array_equal(t0[:2], t1[1:3])
                np.testing.assert_array_equal(t0[:3], t2[1:4])
                if t0[0].startswith('This'):
                    self.assertEqual(i3, 0)
                else:
                    self.assertEqual(i3, 1)
                self.assertEqual(n1, 128)
                self.assertEqual(n2, 512)
                self.assertIsInstance(n1, torch.Tensor)
                self.assertIsInstance(n2, torch.Tensor)
                self.assertIsInstance(t4, str)

            if discard_index is not None:
                hpms = text_data._hparams.datasets[discard_index]
                max_l = hpms.max_seq_length
                # Account for BOS/EOS tokens if they are enabled for this
                # dataset (each adds one position to the sequence length).
                max_l += sum(
                    int(x is not None and x != '') for x in [
                        text_data.vocab(discard_index).bos_token,
                        text_data.vocab(discard_index).eos_token
                    ])
                for i in range(2):
                    for length in batch[text_data.length_name(i)]:
                        self.assertLessEqual(length, max_l)

                # TODO(avinash): Add this back once variable utterance is added
                # for lengths in batch[text_data.length_name(2)]:
                #    for length in lengths:
                #        self.assertLessEqual(length, max_l)

            for i, hpms in enumerate(text_data._hparams.datasets):
                if hpms.data_type != "text":
                    continue
                max_l = hpms.max_seq_length
                mode = hpms.length_filter_mode
                if max_l is not None and mode == "truncate":
                    max_l += sum(
                        int(x is not None and x != '') for x in [
                            text_data.vocab(i).bos_token,
                            text_data.vocab(i).eos_token
                        ])
                    for length in batch[text_data.length_name(i)]:
                        self.assertLessEqual(length, max_l)
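A hypothetical driver for `_run_and_test`, sketching the hparams shape that `tx.data.MultiAlignedData` expects. The field names and file paths are assumptions, and the record-type dataset that supplies `4_number1`/`4_number2`/`4_text` is omitted; each entry under 'datasets' describes one aligned field, and text fields surface in the batch as '<data_name>_text':

    def test_basic(self):
        hparams = {
            'batch_size': 8,
            'datasets': [
                # Hypothetical file names; the three text fields are aligned
                # line by line, as the t0/t1/t2 slice checks above assume.
                {'files': 'text0.txt', 'vocab_file': 'vocab.txt',
                 'data_name': '0'},
                {'files': 'text1.txt', 'vocab_file': 'vocab.txt',
                 'data_name': '1'},
                {'files': 'text2.txt', 'vocab_file': 'vocab.txt',
                 'data_name': '2'},
                # Non-text fields set 'data_type'; an int field surfaces
                # under its plain data_name, matching batch['label'] above.
                {'files': 'label.txt', 'data_type': 'int',
                 'data_name': 'label'},
            ],
        }
        self._run_and_test(hparams)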