def test_iterator_single_dataset(self):
    r"""Checks that a lone dataset can be traversed batch by batch.

    The test data contains the numbers 1001..2000, one per example, so
    every batch holds consecutive string-typed numbers of length 1.
    """
    dataset = MonoTextData(self._test_hparams)
    data_iterator = DataIterator(dataset)
    data_iterator.switch_to_dataset(dataset_name="data")
    expected_size = self._test_hparams['batch_size']

    expected = 1001
    for batch in data_iterator.get_iterator():
        self.assertEqual(batch.batch_size, expected_size)
        # Each example is a single token, hence length 1 for both rows.
        np.testing.assert_array_equal(batch['length'], [1, 1])
        for example in batch['text']:
            self.assertEqual(example[0], str(expected))
            expected += 1
    # All 1000 examples (1001..2000) must have been consumed exactly once.
    self.assertEqual(expected, 2001)
def _test_modes_with_workers(self, lazy_mode: str, cache_mode: str,
                             num_workers: int):
    r"""Benchmarks data construction and iteration for a lazy/cache mode.

    Loads a parallel corpus with the given ``lazy_strategy`` /
    ``cache_strategy`` / worker count, iterates it twice, and reports the
    process memory footprint at each stage.

    Args:
        lazy_mode: Value for the ``lazy_strategy`` hyperparameter.
        cache_mode: Value for the ``cache_strategy`` hyperparameter.
        num_workers: Number of parallel worker calls.
    """
    from tqdm import tqdm
    gc.collect()
    with work_in_progress(f"Data loading with lazy mode '{lazy_mode}' "
                          f"and cache mode '{cache_mode}' "
                          f"with {num_workers} workers"):
        # BUG FIX: memory was previously sampled only once, before
        # construction, and the same stale value was printed for every
        # "after" report. Sample afresh before each print so the numbers
        # actually reflect the current process state.
        print(f"Memory before: {get_process_memory():.2f} MB")
        with work_in_progress("Construction"):
            data = ParallelData(
                self.source,
                '../../Downloads/src.vocab',
                '../../Downloads/tgt.vocab',
                {
                    'batch_size': self.batch_size,
                    'lazy_strategy': lazy_mode,
                    'cache_strategy': cache_mode,
                    'num_parallel_calls': num_workers,
                    'shuffle': False,
                    'allow_smaller_final_batch': False,
                    'max_dataset_size': 100000,
                })
        print(f"Memory after construction: {get_process_memory():.2f} MB")
        iterator = DataIterator(data)
        with work_in_progress("Iteration"):
            for batch in tqdm(iterator, leave=False):
                self.assertEqual(batch.batch_size, self.batch_size)
        # Collect garbage so the post-iteration figure is not inflated by
        # unreachable batch objects.
        gc.collect()
        print(f"Memory after iteration: {get_process_memory():.2f} MB")
        # A second pass exercises the cache path (data already loaded).
        with work_in_progress("2nd iteration"):
            for batch in tqdm(iterator, leave=False):
                self.assertEqual(batch.batch_size, self.batch_size)
def _test_modes_with_workers(self, lazy_mode: str, cache_mode: str,
                             num_workers: int,
                             parallelize_processing: bool = True,
                             support_random_access: bool = False,
                             shuffle: bool = False,
                             **kwargs):
    r"""Runs two epochs over a mock dataset and validates caching behavior.

    Builds a :class:`MockDataBase` over zipped number/string sources with
    the requested laziness, caching, worker, and shuffling settings, then
    checks batch contents plus the sizes of the processed-example cache
    and the raw-source cache before the first epoch and after each epoch.

    Args:
        lazy_mode: ``lazy_strategy`` hyperparameter ('none'/'process'/...).
        cache_mode: ``cache_strategy`` hyperparameter
            ('none'/'loaded'/'processed').
        num_workers: Number of parallel processing calls.
        parallelize_processing: Whether processing is parallelized; cache
            assertions only apply in this mode.
        support_random_access: If ``True``, use a random-access source
            (no raw-source cache is involved).
        shuffle: Whether to shuffle; exact batch contents are only checked
            when ``False``.
        **kwargs: Extra hyperparameters merged into ``hparams``.
    """
    hparams = {
        'batch_size': self.batch_size,
        'lazy_strategy': lazy_mode,
        'cache_strategy': cache_mode,
        'num_parallel_calls': num_workers,
        'shuffle': shuffle,
        'shuffle_buffer_size': self.size // 5,
        'parallelize_processing': parallelize_processing,
        'allow_smaller_final_batch': False,
        **kwargs,
    }
    numbers_data = [[x] * self.seq_len for x in range(self.size)]
    string_data = [' '.join(map(str, range(self.seq_len)))
                   for _ in range(self.size)]
    if not support_random_access:
        source = ZipDataSource(  # type: ignore
            IterDataSource(numbers_data),
            SequenceDataSource(string_data))
    else:
        source = ZipDataSource(SequenceDataSource(numbers_data),
                               SequenceDataSource(string_data))
    data = MockDataBase(source, hparams)  # type: ignore
    iterator = DataIterator(data)

    if data._hparams.allow_smaller_final_batch:
        total_examples = self.size
        total_batches = (self.size + self.batch_size - 1) // self.batch_size
    else:
        total_examples = self.size // self.batch_size * self.batch_size
        total_batches = self.size // self.batch_size

    def check_batch(idx, batch):
        # The final batch may be smaller when smaller batches are allowed.
        if idx == total_batches - 1:
            batch_size = (total_examples - 1) % self.batch_size + 1
        else:
            batch_size = self.batch_size
        self.assertEqual(batch.numbers.shape, (batch_size, self.seq_len))
        if not shuffle:
            # Without shuffling, batch `idx` holds consecutive numbers
            # idx*batch_size+1 .. idx*batch_size+batch_size.
            numbers = np.asarray([idx * self.batch_size + x + 1
                                  for x in range(batch_size)])
            self.assertTrue(np.all(batch.numbers == numbers[:, np.newaxis]))

    def run_epoch():
        # One full pass; verifies every batch and the total batch count.
        cnt = 0
        for idx, batch in enumerate(iterator):
            check_batch(idx, batch)
            cnt += 1
        self.assertEqual(cnt, total_batches)

    def check_cache():
        # Post-epoch cache invariants. Previously this block was
        # duplicated verbatim after each of the two epochs.
        if parallelize_processing:
            if cache_mode == 'none':
                self.assertEqual(len(data._processed_cache), 0)
            elif cache_mode == 'loaded':
                self.assertEqual(len(data._processed_cache), 0)
            else:
                self.assertEqual(len(data._processed_cache), self.size)
            if lazy_mode != 'none' and not support_random_access:
                if cache_mode == 'none':
                    self.assertEqual(len(data._cached_source._cache), 0)
                elif cache_mode == 'loaded':
                    self.assertEqual(
                        len(data._cached_source._cache), self.size)
                else:
                    self.assertEqual(len(data._cached_source._cache), 0)

    # check laziness: with eager processing ('none') everything is
    # processed at construction time; otherwise nothing is yet.
    if parallelize_processing:
        if lazy_mode == 'none':
            self.assertEqual(len(data._processed_cache), self.size)
        else:
            self.assertEqual(len(data._processed_cache), 0)
            if not support_random_access:
                if lazy_mode == 'process':
                    self.assertEqual(
                        len(data._cached_source._cache), self.size)
                else:
                    self.assertEqual(len(data._cached_source._cache), 0)

    # first epoch
    run_epoch()
    check_cache()

    # second epoch: results and cache state must be identical.
    run_epoch()
    check_cache()
def test_iterator_multi_datasets(self):
    r"""Verifies switching among several named datasets while iterating.

    The "train" split holds numbers 1..., the "test" split 1001...; since
    the first 4 vocabulary entries are special tokens, the id of number
    ``n`` is ``n + 3``.
    """
    train = MonoTextData(self._train_hparams)
    test = MonoTextData(self._test_hparams)
    train_bs = self._train_hparams["batch_size"]
    test_bs = self._test_hparams["batch_size"]
    data_iterator = DataIterator({"train": train, "test": test})

    def _check(batch_idx, batch, bs, offset):
        # Shared verification for a single batch of either split.
        self.assertEqual(len(batch), bs)
        number = batch_idx * bs + offset
        self.assertEqual(batch.text[0], [str(number)])
        # numbers: 1 - 2000, first 4 vocab entries are special tokens
        self.assertEqual(batch.text_ids[0], torch.tensor(number + 3))

    data_iterator.switch_to_dataset(dataset_name="train")
    for idx, val in enumerate(data_iterator.get_iterator()):
        _check(idx, val, train_bs, 1)

    data_iterator.switch_to_dataset(dataset_name="test")
    for idx, val in enumerate(data_iterator.get_iterator()):
        _check(idx, val, test_bs, 1001)

    # `get_iterator` may also select the dataset directly by name.
    for idx, val in enumerate(data_iterator.get_iterator('train')):
        _check(idx, val, train_bs, 1)

    # Switching to an unknown dataset name must raise ValueError.
    with self.assertRaises(ValueError) as context:
        data_iterator.switch_to_dataset('val')
    self.assertTrue('not found' in str(context.exception))
def _run_and_test(self, hparams, discard_index=None):
    r"""Builds a :class:`MultiAlignedData` from ``hparams`` and validates it.

    Iterates the whole dataset once, checking alignment between the text
    fields, the derived label, the scalar number fields, and (optionally)
    that sequence lengths respect the ``max_seq_length`` of a dataset.

    Args:
        hparams: Hyperparameters for ``tx.data.MultiAlignedData``.
        discard_index: If given, index of the dataset whose
            ``max_seq_length`` must bound every example length.
    """
    # Construct database
    text_data = tx.data.MultiAlignedData(hparams)
    # Vocab size = raw vocabulary plus the special tokens prepended to it.
    self.assertEqual(
        text_data.vocab(0).size,
        self._vocab_size + len(text_data.vocab(0).special_tokens))

    iterator = DataIterator(text_data)
    for batch in iterator:
        # Every declared item must appear in the batch, and nothing else.
        self.assertEqual(set(batch.keys()),
                         set(text_data.list_items()))
        text_0 = batch['0_text']
        text_1 = batch['1_text']
        text_2 = batch['2_text']
        int_3 = batch['label']
        number_1 = batch['4_number1']
        number_2 = batch['4_number2']
        text_4 = batch['4_text']

        for t0, t1, t2, i3, n1, n2, t4 in zip(
                text_0, text_1, text_2, int_3, number_1, number_2, text_4):
            # Datasets 1 and 2 are shifted copies of dataset 0 (presumably
            # because of BOS-token prefixes — confirm against fixtures).
            np.testing.assert_array_equal(t0[:2], t1[1:3])
            np.testing.assert_array_equal(t0[:3], t2[1:4])
            # Label is 0 for sentences starting with 'This', else 1.
            if t0[0].startswith('This'):
                self.assertEqual(i3, 0)
            else:
                self.assertEqual(i3, 1)
            # The number fields carry fixed values as tensors; the extra
            # text field stays a plain Python string.
            self.assertEqual(n1, 128)
            self.assertEqual(n2, 512)
            self.assertTrue(isinstance(n1, torch.Tensor))
            self.assertTrue(isinstance(n2, torch.Tensor))
            self.assertTrue(isinstance(t4, str))

        if discard_index is not None:
            hpms = text_data._hparams.datasets[discard_index]
            max_l = hpms.max_seq_length
            # Account for BOS/EOS tokens, which add to the stored length.
            max_l += sum(int(x is not None and x != '')
                         for x in [text_data.vocab(discard_index).bos_token,
                                   text_data.vocab(discard_index).eos_token])
            for i in range(2):
                for length in batch[text_data.length_name(i)]:
                    self.assertLessEqual(length, max_l)

            # TODO(avinash): Add this back once variable utterance is added
            # for lengths in batch[text_data.length_name(2)]:
            #     for length in lengths:
            #         self.assertLessEqual(length, max_l)

        # In 'truncate' mode, no stored length may exceed max_seq_length
        # (again adjusted for BOS/EOS tokens).
        for i, hpms in enumerate(text_data._hparams.datasets):
            if hpms.data_type != "text":
                continue
            max_l = hpms.max_seq_length
            mode = hpms.length_filter_mode
            if max_l is not None and mode == "truncate":
                max_l += sum(int(x is not None and x != '')
                             for x in [text_data.vocab(i).bos_token,
                                       text_data.vocab(i).eos_token])
                for length in batch[text_data.length_name(i)]:
                    self.assertLessEqual(length, max_l)