Code Example #1
    def test_bucketing(self):
        r"""Tests bucketing.
        """
        hparams = copy.deepcopy(self._hparams)
        hparams.update({
            "bucket_boundaries": [7],
            "bucket_batch_sizes": [6, 4]
        })

        text_data = MonoTextData(hparams)
        iterator = DataIterator(text_data)

        hparams.update({
            "bucket_boundaries": [7],
            "bucket_batch_sizes": [7, 7],
            "allow_smaller_final_batch": False
        })

        text_data_1 = MonoTextData(hparams)
        iterator_1 = DataIterator(text_data_1)

        for data_batch, data_batch_1 in zip(iterator, iterator_1):
            length = data_batch['length'][0]
            if length < 7:
                last_batch_size = hparams['num_epochs'] % 6
                self.assertTrue(
                    len(data_batch['text']) == 6
                    or len(data_batch['text']) == last_batch_size)
            else:
                last_batch_size = hparams['num_epochs'] % 4
                self.assertTrue(
                    len(data_batch['text']) == 4
                    or len(data_batch['text']) == last_batch_size)

            self.assertEqual(len(data_batch_1['text']), 7)
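
These snippets are test methods lifted out of their modules, so the surrounding imports and `setUp` fixtures are not shown. As a rough, hedged sketch, they appear to rely on texar-pytorch's public data API along the following lines (exact symbols and module layout may differ between versions; project-local helpers such as `ParallelData`, `MockDataBase`, `IU_XRay_Dataset`, `work_in_progress`, and `get_process_memory` are defined in the test projects themselves, and the dataset base class appears as `DataBase` in one snippet and `DatasetBase` in another, which seems to reflect a rename across texar-pytorch versions):

    import copy
    import gc

    import numpy as np
    import torch

    from texar.torch.data import (
        Batch, DataIterator, IterDataSource, MonoTextData, MultiAlignedData,
        PairedTextData, RecordData, SequenceDataSource,
        TokenCountBatchingStrategy, ZipDataSource,
    )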
Code Example #2
    def test_iterator_single_dataset(self):
        r"""Tests iterating over a single dataset.
        """
        data = MonoTextData(self._test_hparams)
        data_iterator = DataIterator(data)
        data_iterator.switch_to_dataset(dataset_name="data")
        iterator = data_iterator.get_iterator()
        i = 1001
        for batch in iterator:
            self.assertEqual(batch.batch_size,
                             self._test_hparams['batch_size'])
            np.testing.assert_array_equal(batch['length'], [1, 1])
            for example in batch['text']:
                self.assertEqual(example[0], str(i))
                i += 1
        self.assertEqual(i, 2001)
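
`self._test_hparams` comes from the test's `setUp()` and is not shown. Judging from the assertions (a batch size of 2 and one-token lines holding the numbers 1001 through 2000), it is an ordinary `MonoTextData` configuration; a minimal, hedged sketch with hypothetical file paths:

    from texar.torch.data import DataIterator, MonoTextData

    test_hparams = {
        "num_epochs": 1,
        "batch_size": 2,
        "shuffle": False,
        "dataset": {
            "files": "test_numbers.txt",     # hypothetical: lines "1001" .. "2000"
            "vocab_file": "test_vocab.txt",  # hypothetical vocabulary file
        },
    }
    data = MonoTextData(test_hparams)
    iterator = DataIterator(data)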
Code Example #3
    def _run_and_test(self, hparams):
        # Construct database
        text_data = MonoTextData(hparams)
        self.assertEqual(
            text_data.vocab.size,
            self._vocab_size + len(text_data.vocab.special_tokens))

        iterator = DataIterator(text_data)

        for data_batch in iterator:
            # Check that the batch contains exactly the listed items
            self.assertEqual(set(data_batch.keys()),
                             set(text_data.list_items()))

            # Test utterance count
            utt_ind = np.sum(data_batch["text_ids"], 2) != 0
            utt_cnt = np.sum(utt_ind, 1)
            self.assertListEqual(
                data_batch[text_data.utterance_cnt_name].tolist(),
                utt_cnt.tolist())

            if text_data.hparams.dataset.pad_to_max_seq_length:
                max_l = text_data.hparams.dataset.max_seq_length
                max_l += text_data._decoder.added_length
                for x in data_batch['text']:
                    for xx in x:
                        self.assertEqual(len(xx), max_l)
                for x in data_batch['text_ids']:
                    for xx in x:
                        self.assertEqual(len(xx), max_l)
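
Example #3 exercises the variable-utterance form of `MonoTextData`: `text_ids` is three-dimensional (batch, utterance, token), so the expected utterance count is recovered by marking which utterance rows contain any non-zero token ids, and, when `pad_to_max_seq_length` is enabled, every utterance is additionally checked to be padded to `max_seq_length` plus whatever length the decoder adds for special tokens.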
Code Example #4
    def test_dynamic_batching(self):
        r"""Tests dynamic batching using :class:`texar.torch.data.BatchingStrategy`.
        """
        sent_lengths = np.random.randint(10, 20, size=(100, ))
        sentences = [['a'] * length for length in sent_lengths]
        data_source = SequenceDataSource(sentences)

        class CustomData(DataBase):
            def __init__(self, source):
                super().__init__(source)

            def process(self, raw_example):
                return raw_example

            def collate(self, examples):
                return Batch(len(examples), text=examples)

        train_data = CustomData(data_source)

        batch_size = 5
        max_tokens = 75
        strategy = TokenCountBatchingStrategy(max_tokens, batch_size, len)
        iterator = DataIterator(train_data, strategy)

        for batch in iterator:
            self.assertLessEqual(len(batch), batch_size)
            self.assertLessEqual(sum(len(s) for s in batch.text), max_tokens)
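
`TokenCountBatchingStrategy(max_tokens, batch_size, len)` caps each dynamically sized batch both by example count and by total token count (here a sentence's "token count" is simply its length, since `len` is passed as the length function). Custom strategies appear to use the same small interface; below is a hedged sketch of one that only caps the example count, assuming the `reset_batch`/`add_example` protocol that `TokenCountBatchingStrategy` implements (the class and its behavior here are illustrative, not part of the library):

    from texar.torch.data import BatchingStrategy

    class MaxCountStrategy(BatchingStrategy):
        # Hypothetical strategy: start a new batch once `max_examples`
        # examples have been accepted.
        def __init__(self, max_examples: int):
            self.max_examples = max_examples
            self.count = 0

        def reset_batch(self) -> None:
            # Called by the iterator whenever a new batch is started.
            self.count = 0

        def add_example(self, example) -> bool:
            # Returning False signals that the example does not fit; the
            # iterator emits the current batch and retries the example.
            if self.count >= self.max_examples:
                return False
            self.count += 1
            return True

Such a strategy would be plugged in exactly as above: `DataIterator(train_data, MaxCountStrategy(5))`.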
Code Example #5
    def _test_modes_with_workers(self, lazy_mode: str, cache_mode: str,
                                 num_workers: int):
        from tqdm import tqdm
        gc.collect()
        mem = get_process_memory()
        with work_in_progress(f"Data loading with lazy mode '{lazy_mode}' "
                              f"and cache mode '{cache_mode}' "
                              f"with {num_workers} workers"):
            print(f"Memory before: {mem:.2f} MB")
            with work_in_progress("Construction"):
                data = ParallelData(self.source,
                                    '../../Downloads/src.vocab',
                                    '../../Downloads/tgt.vocab',
                                    {'batch_size': self.batch_size,
                                     'lazy_strategy': lazy_mode,
                                     'cache_strategy': cache_mode,
                                     'num_parallel_calls': num_workers,
                                     'shuffle': False,
                                     'allow_smaller_final_batch': False,
                                     'max_dataset_size': 100000})
            # Re-measure so the reported number reflects memory after
            # construction rather than the stale value captured above.
            mem = get_process_memory()
            print(f"Memory after construction: {mem:.2f} MB")
            iterator = DataIterator(data)
            with work_in_progress("Iteration"):
                for batch in tqdm(iterator, leave=False):
                    self.assertEqual(batch.batch_size, self.batch_size)
            gc.collect()
            mem = get_process_memory()
            print(f"Memory after iteration: {mem:.2f} MB")
            with work_in_progress("2nd iteration"):
                for batch in tqdm(iterator, leave=False):
                    self.assertEqual(batch.batch_size, self.batch_size)
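
`get_process_memory` and `work_in_progress` are local test helpers that are not shown here: the former reports the current process's memory usage in MB, the latter is a context manager that labels and times a block of work. A hedged, psutil-based sketch of the memory helper (the real implementation may differ):

    import os
    import psutil

    def get_process_memory() -> float:
        # Resident set size of the current process, in megabytes.
        return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)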
Code Example #6
    def setUp(self):
        self.batch_size = 4
        self.num_label = len(pathologies)
        data_hparams = {
            "datasource": {
                "img_root": "tests/test_iu_xray_data/iu_xray_images",
                "text_root": "tests/test_iu_xray_data/text_root",
                "vocab_path": "tests/test_iu_xray_data/test_vocab.txt",
                "transforms": transforms,
                "pathologies": pathologies,
            },
            "batch_size": self.batch_size,
            "shuffle": False,
        }
        dataset = IU_XRay_Dataset(data_hparams)
        dataset.to(torch.device('cpu'))
        self.loader = DataIterator(dataset)

        self.extractor = SimpleFusionEncoder()
        mlc_hparam = {
            'num_tags': len(pathologies),
        }
        self.mlc = MLC(mlc_hparam)
        self.mlc_trainer = MLCTrainer(mlc_hparam)

        self.loss = torch.nn.BCEWithLogitsLoss()
Code Example #7
    def _run_and_test(self, hparams):
        # Construct database
        record_data = RecordData(hparams)
        iterator = DataIterator(record_data)

        def _prod(lst):
            res = 1
            for i in lst:
                res *= i
            return res

        for idx, data_batch in enumerate(iterator):
            self.assertEqual(set(data_batch.keys()),
                             set(record_data.list_items()))

            # Check data consistency
            for key in self._unconvert_features:
                value = data_batch[key][0]
                self.assertEqual(value, self._dataset_valid[key][idx])
            self.assertEqual(list(data_batch['shape'][0]),
                             list(self._dataset_valid['shape'][idx]))

            # Check data type conversion
            for key, item in self._feature_convert_types.items():
                dtype = get_numpy_dtype(item)
                value = data_batch[key][0]
                if dtype is np.str_:
                    self.assertIsInstance(value, str)
                elif dtype is np.bytes_:
                    self.assertIsInstance(value, bytes)
                else:
                    if isinstance(value, torch.Tensor):
                        value_dtype = get_numpy_dtype(value.dtype)
                    else:
                        value_dtype = value.dtype
                    dtype_matched = np.issubdtype(value_dtype, dtype)
                    self.assertTrue(dtype_matched)

            # Check image decoding and resize
            if hparams["dataset"].get("image_options"):
                image_options = hparams["dataset"].get("image_options")
                if isinstance(image_options, dict):
                    image_options = [image_options]
                for image_option_feature in image_options:
                    image_key = image_option_feature.get("image_feature_name")
                    if image_key is None:
                        continue
                    image_gen = data_batch[image_key][0]
                    image_valid_shape = self._dataset_valid["shape"][idx]
                    resize_height = image_option_feature.get("resize_height")
                    resize_width = image_option_feature.get("resize_width")
                    if resize_height and resize_width:
                        self.assertEqual(
                            image_gen.shape[0] * image_gen.shape[1],
                            resize_height * resize_width)
                    else:
                        self.assertEqual(_prod(image_gen.shape),
                                         _prod(image_valid_shape))
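
Example #7 validates three aspects of a `RecordData` batch: unconverted features round-trip unchanged, features listed in `_feature_convert_types` come back as the requested type (string, bytes, or a compatible numeric dtype), and any feature named in the dataset's `image_options` is decoded and, when `resize_height`/`resize_width` are given, resized so that its total pixel count matches the requested size.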
Code Example #8
    def _run_and_test(self,
                      hparams,
                      test_batch_size=False,
                      test_transform=False):
        # Construct database
        text_data = MonoTextData(hparams)
        self.assertEqual(
            text_data.vocab.size,
            self._vocab_size + len(text_data.vocab.special_tokens))

        iterator = DataIterator(text_data)

        for data_batch in iterator:
            self.assertEqual(set(data_batch.keys()),
                             set(text_data.list_items()))

            if test_batch_size:
                self.assertEqual(len(data_batch['text']),
                                 hparams['batch_size'])

            if test_transform:
                for i in range(len(data_batch['text'])):
                    text_ = data_batch['text'][i]
                    self.assertTrue(text_ in self.upper_cased_text)

            max_seq_length = text_data.hparams.dataset.max_seq_length
            mode = text_data.hparams.dataset.length_filter_mode

            max_l = max_seq_length
            if max_seq_length is not None:
                if text_data.hparams.dataset.eos_token != '':
                    max_l += 1
                if text_data.hparams.dataset.bos_token != '':
                    max_l += 1

            if max_seq_length == 6:
                for length in data_batch['length']:
                    self.assertLessEqual(length, max_l)
                if mode == "discard":
                    for length in data_batch['length']:
                        self.assertEqual(length, 5)
                elif mode == "truncate":
                    num_length_6 = 0
                    for length in data_batch['length']:
                        num_length_6 += int(length == 6)
                    self.assertGreater(num_length_6, 0)
                else:
                    raise ValueError("Unknown mode: %s" % mode)

            if text_data.hparams.dataset.pad_to_max_seq_length:
                for x in data_batch['text']:
                    self.assertEqual(len(x), max_l)
                for x in data_batch['text_ids']:
                    self.assertEqual(len(x), max_l)
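
The length checks in Example #8 account for the BOS/EOS tokens that `MonoTextData` adds unless they are configured as empty strings: the effective maximum length is `max_seq_length` plus one for each enabled special token. With `length_filter_mode` set to "discard", over-long examples are dropped entirely, while with "truncate" they are kept but cut to the limit, so lengths equal to `max_seq_length` must still appear in the batches.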
Code Example #9
    def _run_and_test(self,
                      hparams,
                      proc_shr=False,
                      test_transform=None,
                      discard_src=False):
        # Construct database
        text_data = PairedTextData(hparams)
        self.assertEqual(
            text_data.source_vocab.size,
            self._vocab_size + len(text_data.source_vocab.special_tokens))

        iterator = DataIterator(text_data)
        for data_batch in iterator:
            self.assertEqual(set(data_batch.keys()),
                             set(text_data.list_items()))

            if proc_shr:
                tgt_eos = '<EOS>'
            else:
                tgt_eos = '<TARGET_EOS>'

            # Test matching
            src_text = data_batch['source_text']
            tgt_text = data_batch['target_text']
            if proc_shr:
                for src, tgt in zip(src_text, tgt_text):
                    np.testing.assert_array_equal(src[:3], tgt[:3])
            else:
                for src, tgt in zip(src_text, tgt_text):
                    np.testing.assert_array_equal(src[:3], tgt[1:4])
            self.assertTrue(tgt_eos in data_batch['target_text'][0])

            if test_transform:
                for i in range(len(data_batch['source_text'])):
                    text_ = data_batch['source_text'][i]
                    self.assertTrue(text_ in self.src_upper_cased_text)
                for i in range(len(data_batch['target_text'])):
                    text_ = data_batch['target_text'][i]
                    self.assertTrue(text_ in self.tgt_upper_cased_text)

            if discard_src:
                src_hparams = text_data.hparams.source_dataset
                max_l = src_hparams.max_seq_length
                max_l += sum(
                    int(x is not None and x != '') for x in
                    [text_data._src_bos_token, text_data._src_eos_token])
                for l in data_batch["source_length"]:
                    self.assertLessEqual(l, max_l)
Code Example #10
    def test_auto_storage_moving(self):
        cuda_tensors = set()

        def move_tensor(tensor, device, non_blocking=False):
            if isinstance(device, torch.device) and device.type == "cuda":
                self.assertTrue(non_blocking)
                cuda_tensors.add(id(tensor))
            return tensor

        device = torch.device("cuda:0")

        with patch.object(torch.Tensor, "to", move_tensor):
            train = MonoTextData(self._train_hparams, device=device)
            iterator = DataIterator(train)
            for batch in iterator:
                self.assertTrue(id(batch.text_ids) in cuda_tensors)
                self.assertTrue(id(batch.length) in cuda_tensors)
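
Example #10 does not need a real GPU: it patches `torch.Tensor.to` and merely records which tensors `DataIterator` attempts to move, asserting that the move is requested with `non_blocking=True`. In normal use the only step required is the one shown, passing `device=torch.device("cuda:0")` when constructing the data, after which yielded batches already live on that device.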
Code Example #11
    def setUp(self):
        self.batch_size = 4
        self.num_label = len(pathologies)
        hparams = {
            "datasource": {
                "img_root": "tests/test_iu_xray_data/iu_xray_images",
                "text_root": "tests/test_iu_xray_data/text_root",
                "vocab_path": "tests/test_iu_xray_data/test_vocab.txt",
                "transforms": transforms,
                "pathologies": pathologies,
            },
            "batch_size": self.batch_size,
            "shuffle": False,
        }
        dataset = IU_XRay_Dataset(hparams)
        dataset.to(torch.device('cpu'))
        self.vocab = dataset.source.vocab

        self.ground_truth_keys = [
            'img_tensor', 'label', 'token_tensor', 'stop_prob'
        ]
        self.ground_truth_findings = [
            'cardiac and mediastinal contours '
            'are within normal limits <EOS> the '
            'lungs are clear <EOS> bony structures '
            'are intact <EOS>'
        ]

        self.ground_truth_token_tensors = torch.Tensor([
            [61, 10, 36, 55, 7, 25, 8, 28, 2, 0, 0, 0],
            [5, 19, 7, 21, 2, 0, 0, 0, 0, 0, 0, 0],
            [58, 52, 7, 73, 2, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]).to(torch.long)
        self.loader = DataIterator(dataset)
Code Example #12
    def _run_and_test(self, hparams, discard_index=None):
        # Construct database
        text_data = MultiAlignedData(hparams)
        self.assertEqual(
            text_data.vocab(0).size,
            self._vocab_size + len(text_data.vocab(0).special_tokens))

        iterator = DataIterator(text_data)
        for batch in iterator:
            self.assertEqual(set(batch.keys()), set(text_data.list_items()))
            text_0 = batch['0_text']
            text_1 = batch['1_text']
            text_2 = batch['2_text']
            int_3 = batch['label']
            number_1 = batch['4_number1']
            number_2 = batch['4_number2']
            text_4 = batch['4_text']

            for t0, t1, t2, i3, n1, n2, t4 in zip(text_0, text_1, text_2,
                                                  int_3, number_1, number_2,
                                                  text_4):

                np.testing.assert_array_equal(t0[:2], t1[1:3])
                np.testing.assert_array_equal(t0[:3], t2[1:4])
                if t0[0].startswith('This'):
                    self.assertEqual(i3, 0)
                else:
                    self.assertEqual(i3, 1)
                self.assertEqual(n1, 128)
                self.assertEqual(n2, 512)
                self.assertTrue(isinstance(n1, torch.Tensor))
                self.assertTrue(isinstance(n2, torch.Tensor))
                self.assertTrue(isinstance(t4, str))

            if discard_index is not None:
                hpms = text_data._hparams.datasets[discard_index]
                max_l = hpms.max_seq_length
                max_l += sum(
                    int(x is not None and x != '') for x in [
                        text_data.vocab(discard_index).bos_token,
                        text_data.vocab(discard_index).eos_token
                    ])
                for i in range(2):
                    for length in batch[text_data.length_name(i)]:
                        self.assertLessEqual(length, max_l)

                # TODO(avinash): Add this back once variable utterance is added
                # for lengths in batch[text_data.length_name(2)]:
                #    for length in lengths:
                #        self.assertLessEqual(length, max_l)

            for i, hpms in enumerate(text_data._hparams.datasets):
                if hpms.data_type != "text":
                    continue
                max_l = hpms.max_seq_length
                mode = hpms.length_filter_mode
                if max_l is not None and mode == "truncate":
                    max_l += sum(
                        int(x is not None and x != '') for x in [
                            text_data.vocab(i).bos_token,
                            text_data.vocab(i).eos_token
                        ])
                    for length in batch[text_data.length_name(i)]:
                        self.assertLessEqual(length, max_l)
Code Example #13
    def _test_modes_with_workers(self,
                                 lazy_mode: str,
                                 cache_mode: str,
                                 num_workers: int,
                                 parallelize_processing: bool = True,
                                 support_random_access: bool = False,
                                 shuffle: bool = False,
                                 **kwargs):
        hparams = {
            'batch_size': self.batch_size,
            'lazy_strategy': lazy_mode,
            'cache_strategy': cache_mode,
            'num_parallel_calls': num_workers,
            'shuffle': shuffle,
            'shuffle_buffer_size': self.size // 5,
            'parallelize_processing': parallelize_processing,
            'allow_smaller_final_batch': False,
            **kwargs,
        }
        numbers_data = [[x] * self.seq_len for x in range(self.size)]
        string_data = [
            ' '.join(map(str, range(self.seq_len))) for _ in range(self.size)
        ]
        if not support_random_access:
            source = ZipDataSource(  # type: ignore
                IterDataSource(numbers_data), SequenceDataSource(string_data))
        else:
            source = ZipDataSource(SequenceDataSource(numbers_data),
                                   SequenceDataSource(string_data))
        data = MockDataBase(source, hparams)  # type: ignore
        iterator = DataIterator(data)

        if data._hparams.allow_smaller_final_batch:
            total_examples = self.size
            total_batches = (self.size + self.batch_size -
                             1) // self.batch_size
        else:
            total_examples = self.size // self.batch_size * self.batch_size
            total_batches = self.size // self.batch_size

        def check_batch(idx, batch):
            if idx == total_batches - 1:
                batch_size = (total_examples - 1) % self.batch_size + 1
            else:
                batch_size = self.batch_size
            self.assertEqual(batch.numbers.shape, (batch_size, self.seq_len))
            if not shuffle:
                numbers = np.asarray(
                    [idx * self.batch_size + x + 1 for x in range(batch_size)])
                self.assertTrue(
                    np.all(batch.numbers == numbers[:, np.newaxis]))

        # check laziness
        if parallelize_processing:
            if lazy_mode == 'none':
                self.assertEqual(len(data._processed_cache), self.size)
            else:
                self.assertEqual(len(data._processed_cache), 0)
                if not support_random_access:
                    if lazy_mode == 'process':
                        self.assertEqual(len(data._cached_source._cache),
                                         self.size)
                    else:
                        self.assertEqual(len(data._cached_source._cache), 0)

        # first epoch
        cnt = 0
        for idx, batch in enumerate(iterator):
            check_batch(idx, batch)
            cnt += 1
        self.assertEqual(cnt, total_batches)

        # check cache
        if parallelize_processing:
            if cache_mode == 'none':
                self.assertEqual(len(data._processed_cache), 0)
            elif cache_mode == 'loaded':
                self.assertEqual(len(data._processed_cache), 0)
            else:
                self.assertEqual(len(data._processed_cache), self.size)
            if lazy_mode != 'none' and not support_random_access:
                if cache_mode == 'none':
                    self.assertEqual(len(data._cached_source._cache), 0)
                elif cache_mode == 'loaded':
                    self.assertEqual(len(data._cached_source._cache),
                                     self.size)
                else:
                    self.assertEqual(len(data._cached_source._cache), 0)

        # second epoch
        cnt = 0
        for idx, batch in enumerate(iterator):
            check_batch(idx, batch)
            cnt += 1
        self.assertEqual(cnt, total_batches)

        # check again
        if parallelize_processing:
            if cache_mode == 'none':
                self.assertEqual(len(data._processed_cache), 0)
            elif cache_mode == 'loaded':
                self.assertEqual(len(data._processed_cache), 0)
            else:
                self.assertEqual(len(data._processed_cache), self.size)
            if lazy_mode != 'none' and not support_random_access:
                if cache_mode == 'none':
                    self.assertEqual(len(data._cached_source._cache), 0)
                elif cache_mode == 'loaded':
                    self.assertEqual(len(data._cached_source._cache),
                                     self.size)
                else:
                    self.assertEqual(len(data._cached_source._cache), 0)
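
Example #13 checks how `lazy_strategy`, `cache_strategy`, `num_parallel_calls`, and the type of data source interact: with lazy mode 'none' every example is processed at construction time, otherwise `_processed_cache` starts empty and, for sources without random access, raw examples may be held in `_cached_source._cache` depending on the strategy. The same invariants are re-verified after each epoch, with cache mode 'loaded' retaining raw examples and the remaining mode (which caches processed examples) retaining processed ones.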
Code Example #14
    def test_iterator_multi_datasets(self):
        r"""Tests iterating over multiple datasets.
        """
        train = MonoTextData(self._train_hparams)
        test = MonoTextData(self._test_hparams)
        train_batch_size = self._train_hparams["batch_size"]
        test_batch_size = self._test_hparams["batch_size"]
        data_iterator = DataIterator({"train": train, "test": test})
        data_iterator.switch_to_dataset(dataset_name="train")
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            # numbers: 1 - 2000, first 4 vocab entries are special tokens
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        data_iterator.switch_to_dataset(dataset_name="test")
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), test_batch_size)
            number = idx * test_batch_size + 1001
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test `get_iterator` interface
        for idx, val in enumerate(data_iterator.get_iterator('train')):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test exception for invalid dataset name
        with self.assertRaises(ValueError) as context:
            data_iterator.switch_to_dataset('val')
        self.assertTrue('not found' in str(context.exception))
Code Example #15
        # (The snippet begins mid-method: the enclosing collate() builds
        # img_tensor, label, token_tensor, and stop_prob from `examples`
        # and returns them as a Batch.)
        return Batch(
            len(examples),
            img_tensor=img_tensor,
            label=label,
            token_tensor=token_tensor,
            stop_prob=stop_prob,
        )

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of hyperparameters with default values.

        Returns: (Dict) default hyperparameters
        """
        hparams = DatasetBase.default_hparams()
        hparams.update({
            "datasource": None,
        })
        return hparams


if __name__ == "__main__":
    dataset_hparams = config.dataset
    dataset = IU_XRay_Dataset(dataset_hparams["train"])
    # Dataloader
    dataset.to(torch.device('cpu'))
    train_loader = DataIterator(dataset)

    for batch in train_loader:
        print(batch)
        break
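
The `__main__` block at the end shows the minimal standalone pattern these examples share: construct the dataset from its hparams, move it to a device with `.to()`, wrap it in `DataIterator`, and iterate to obtain `Batch` objects.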