Code example #1
    def test_bucketing(self):
        r"""Tests bucketing.
        """
        hparams = copy.deepcopy(self._hparams)
        hparams.update({
            "bucket_boundaries": [7],
            "bucket_batch_sizes": [6, 4]
        })

        text_data = MonoTextData(hparams)
        iterator = DataIterator(text_data)

        hparams.update({
            "bucket_boundaries": [7],
            "bucket_batch_sizes": [7, 7],
            "allow_smaller_final_batch": False
        })

        text_data_1 = MonoTextData(hparams)
        iterator_1 = DataIterator(text_data_1)

        for data_batch, data_batch_1 in zip(iterator, iterator_1):
            length = data_batch['length'][0]
            if length < 7:
                last_batch_size = hparams['num_epochs'] % 6
                self.assertTrue(
                    len(data_batch['text']) == 6
                    or len(data_batch['text']) == last_batch_size)
            else:
                last_batch_size = hparams['num_epochs'] % 4
                self.assertTrue(
                    len(data_batch['text']) == 4
                    or len(data_batch['text']) == last_batch_size)

            self.assertEqual(len(data_batch_1['text']), 7)
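The bucketing test above copies a pre-built `self._hparams`. A minimal sketch of what that base configuration could look like for MonoTextData follows; the file paths and numbers are placeholders, not the values used in the real test:

    base_hparams = {
        "num_epochs": 50,
        "batch_size": 6,
        "allow_smaller_final_batch": True,
        "dataset": {
            "files": "data/sentences.txt",   # one example per line (placeholder path)
            "vocab_file": "data/vocab.txt",  # one token per line (placeholder path)
        },
    }

The test then layers `bucket_boundaries` and `bucket_batch_sizes` on top of this, so examples shorter than 7 tokens fall into the first bucket (batch size 6) and the rest into the second (batch size 4).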
Code example #2
    def _run_and_test(self, hparams):
        # Construct database
        text_data = MonoTextData(hparams)
        self.assertEqual(
            text_data.vocab.size,
            self._vocab_size + len(text_data.vocab.special_tokens))

        iterator = DataIterator(text_data)

        for data_batch in iterator:
            # Run the logic
            self.assertEqual(set(data_batch.keys()),
                             set(text_data.list_items()))

            # Test utterance count
            utt_ind = np.sum(data_batch["text_ids"], 2) != 0
            utt_cnt = np.sum(utt_ind, 1)
            self.assertListEqual(
                data_batch[text_data.utterance_cnt_name].tolist(),
                utt_cnt.tolist())

            if text_data.hparams.dataset.pad_to_max_seq_length:
                max_l = text_data.hparams.dataset.max_seq_length
                max_l += text_data._decoder.added_length
                for x in data_batch['text']:
                    for xx in x:
                        self.assertEqual(len(xx), max_l)
                for x in data_batch['text_ids']:
                    for xx in x:
                        self.assertEqual(len(xx), max_l)
Code example #3
    def _run_and_test(self,
                      hparams,
                      test_batch_size=False,
                      test_transform=False):
        # Construct database
        text_data = MonoTextData(hparams)
        self.assertEqual(
            text_data.vocab.size,
            self._vocab_size + len(text_data.vocab.special_tokens))

        iterator = DataIterator(text_data)

        for data_batch in iterator:
            self.assertEqual(set(data_batch.keys()),
                             set(text_data.list_items()))

            if test_batch_size:
                self.assertEqual(len(data_batch['text']),
                                 hparams['batch_size'])

            if test_transform:
                for i in range(len(data_batch['text'])):
                    text_ = data_batch['text'][i]
                    self.assertTrue(text_ in self.upper_cased_text)

            max_seq_length = text_data.hparams.dataset.max_seq_length
            mode = text_data.hparams.dataset.length_filter_mode

            max_l = max_seq_length
            if max_seq_length is not None:
                if text_data.hparams.dataset.eos_token != '':
                    max_l += 1
                if text_data.hparams.dataset.bos_token != '':
                    max_l += 1

            if max_seq_length == 6:
                for length in data_batch['length']:
                    self.assertLessEqual(length, max_l)
                if mode == "discard":
                    for length in data_batch['length']:
                        self.assertEqual(length, 5)
                elif mode == "truncate":
                    num_length_6 = 0
                    for length in data_batch['length']:
                        num_length_6 += int(length == 6)
                    self.assertGreater(num_length_6, 0)
                else:
                    raise ValueError("Unknown mode: %s" % mode)

            if text_data.hparams.dataset.pad_to_max_seq_length:
                for x in data_batch['text']:
                    self.assertEqual(len(x), max_l)
                for x in data_batch['text_ids']:
                    self.assertEqual(len(x), max_l)
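For reference, a sketch of the `dataset` hparams that would drive the length checks above; the values are illustrative and the paths are placeholders:

    hparams_length_filter = {
        "batch_size": 4,
        "dataset": {
            "files": "data/sentences.txt",
            "vocab_file": "data/vocab.txt",
            "max_seq_length": 6,
            "length_filter_mode": "discard",  # or "truncate"
            "pad_to_max_seq_length": True,
        },
    }

With both BOS and EOS tokens enabled, the effective padded length becomes `max_seq_length + 2`, which is exactly the `max_l` the test computes before checking batch contents.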
Code example #4
    def test_train_test_data_iterator(self):
        r"""Tests :class:`texar.torch.data.TrainTestDataIterator`
        """
        train = MonoTextData(self._train_hparams)
        test = MonoTextData(self._test_hparams)
        train_batch_size = self._train_hparams["batch_size"]
        test_batch_size = self._test_hparams["batch_size"]

        data_iterator = TrainTestDataIterator(train=train, test=test)
        data_iterator.switch_to_train_data()
        iterator = data_iterator.get_iterator()

        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            # numbers: 1 - 2000, first 4 vocab entries are special tokens
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        data_iterator.switch_to_test_data()
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), test_batch_size)
            number = idx * test_batch_size + 1001
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test `get_*_iterator` interface
        for idx, val in enumerate(data_iterator.get_test_iterator()):
            self.assertEqual(len(val), test_batch_size)
            number = idx * test_batch_size + 1001
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test exception for invalid dataset name
        with self.assertRaises(ValueError) as context:
            data_iterator.switch_to_val_data()
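Outside of tests, the same iterator typically drives a train/evaluate cycle by switching datasets between phases. A minimal usage sketch; `train_hparams` and `test_hparams` are placeholders for valid configurations, and the loop bodies stand in for model code:

    from texar.torch.data import MonoTextData, TrainTestDataIterator

    # Placeholder configurations; see the hparams sketches earlier in this section.
    train_data = MonoTextData(train_hparams)
    test_data = MonoTextData(test_hparams)
    data_iterator = TrainTestDataIterator(train=train_data, test=test_data)

    # Training phase.
    data_iterator.switch_to_train_data()
    for batch in data_iterator.get_iterator():
        ...  # run the training step on `batch`

    # Evaluation phase, via the `get_test_iterator` shortcut exercised above.
    for batch in data_iterator.get_test_iterator():
        ...  # run evaluation on `batch`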
Code example #5
    def test_iterator_multi_datasets(self):
        r"""Tests iterating over multiple datasets.
        """
        train = MonoTextData(self._train_hparams)
        test = MonoTextData(self._test_hparams)
        train_batch_size = self._train_hparams["batch_size"]
        test_batch_size = self._test_hparams["batch_size"]
        data_iterator = DataIterator({"train": train, "test": test})
        data_iterator.switch_to_dataset(dataset_name="train")
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            # numbers: 1 - 2000, first 4 vocab entries are special tokens
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        data_iterator.switch_to_dataset(dataset_name="test")
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), test_batch_size)
            number = idx * test_batch_size + 1001
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test `get_iterator` interface
        for idx, val in enumerate(data_iterator.get_iterator('train')):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test exception for invalid dataset name
        with self.assertRaises(ValueError) as context:
            data_iterator.switch_to_dataset('val')
        self.assertTrue('not found' in str(context.exception))
Code example #6
    def test_iterator_single_dataset(self):
        r"""Tests iterating over a single dataset.
        """
        data = MonoTextData(self._test_hparams)
        data_iterator = DataIterator(data)
        data_iterator.switch_to_dataset(dataset_name="data")
        iterator = data_iterator.get_iterator()
        i = 1001
        for idx, batch in enumerate(iterator):
            self.assertEqual(batch.batch_size,
                             self._test_hparams['batch_size'])
            np.testing.assert_array_equal(batch['length'], [1, 1])
            for example in batch['text']:
                self.assertEqual(example[0], str(i))
                i += 1
        self.assertEqual(i, 2001)
Code example #7
    def test_auto_storage_moving(self):
        cuda_tensors = set()

        def move_tensor(tensor, device, non_blocking=False):
            if isinstance(device, torch.device) and device.type == "cuda":
                self.assertTrue(non_blocking)
                cuda_tensors.add(id(tensor))
            return tensor

        device = torch.device("cuda:0")

        with patch.object(torch.Tensor, "to", move_tensor):
            train = MonoTextData(self._train_hparams, device=device)
            iterator = DataIterator(train)
            for batch in iterator:
                self.assertTrue(id(batch.text_ids) in cuda_tensors)
                self.assertTrue(id(batch.length) in cuda_tensors)
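The patched test verifies that batches are moved to the GPU with `non_blocking=True`. In normal use the `device` argument is all that is needed; a sketch, assuming a CUDA device is available and `train_hparams` is a placeholder for a valid MonoTextData configuration:

    import torch
    from texar.torch.data import MonoTextData, DataIterator

    device = torch.device("cuda:0")  # assumes a GPU is present
    train_data = MonoTextData(train_hparams, device=device)

    for batch in DataIterator(train_data):
        # Tensor fields such as `text_ids` and `length` already live on `device`,
        # so no explicit `.to(device)` is needed in the training loop.
        assert batch.text_ids.device.type == "cuda"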
Code example #8
    def test_list_items(self):
        r"""Tests the item names of the output data.
        """
        text_data = MonoTextData(self._hparams)
        self.assertSetEqual(set(text_data.list_items()),
                            {"text", "text_ids", "length"})

        hparams = copy.deepcopy(self._hparams)
        hparams["dataset"]["data_name"] = "data"
        text_data = MonoTextData(hparams)
        self.assertSetEqual(set(text_data.list_items()),
                            {"data_text", "data_text_ids", "data_length"})
Code example #9
    def make_embedding(src_emb_hparams,
                       src_token_to_id_map,
                       tgt_emb_hparams=None,
                       tgt_token_to_id_map=None,
                       emb_init_share=False):
        r"""Optionally loads source and target embeddings from files (if
        provided), and returns respective :class:`texar.torch.data.Embedding`
        instances.
        """
        src_embedding = MonoTextData.make_embedding(src_emb_hparams,
                                                    src_token_to_id_map)

        if emb_init_share:
            tgt_embedding = src_embedding
        else:
            tgt_emb_file = tgt_emb_hparams["file"]
            tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                tgt_embedding = Embedding(tgt_token_to_id_map, tgt_emb_hparams)

        return src_embedding, tgt_embedding
Code example #10
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        # Defaultizes hyperparameters of each dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for hparams_i in datasets_hparams:
            data_type = hparams_i.get("data_type", None)
            defaultized_ds_hpms = HParams(hparams_i,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        self._vocab = self.make_vocab(self._hparams.datasets)
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        dummy_source = SequenceDataSource[Any]([])
        name_prefix: List[str] = []
        self._names: List[Dict[str, Any]] = []
        sources: List[DataSource] = []
        filters: List[Optional[Callable[[str], bool]]] = []
        self._databases: List[DataBase] = []
        for idx, hparams_i in enumerate(self._hparams.datasets):
            data_type = _DataType(hparams_i.data_type)
            source_i: DataSource

            if _is_text_data(data_type):
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type,
                    delimiter=hparams_i.delimiter)
                sources.append(source_i)
                if ((hparams_i.length_filter_mode
                     == _LengthFilterMode.DISCARD.value)
                        and hparams_i.max_seq_length is not None):

                    def _get_filter(max_seq_length):
                        return lambda x: len(x) <= max_seq_length

                    filters.append(_get_filter(hparams_i.max_seq_length))
                else:
                    filters.append(None)

                self._names.append({
                    field: connect_name(hparams_i.data_name, field)
                    for field in ["text", "text_ids", "length"]
                })

                dataset_hparams = dict_fetch(
                    hparams_i,
                    MonoTextData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = None
                self._databases.append(
                    MonoTextData(hparams={"dataset": dataset_hparams},
                                 device=device,
                                 vocab=self._vocab[idx],
                                 embedding=self._embedding[idx],
                                 data_source=dummy_source))
            elif _is_scalar_data(data_type):
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                sources.append(source_i)
                filters.append(None)
                self._names.append({"data": hparams_i.data_name})

                dataset_hparams = dict_fetch(
                    hparams_i,
                    ScalarData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = "data"
                self._databases.append(
                    ScalarData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            elif _is_record_data(data_type):
                source_i = PickleDataSource(file_paths=hparams_i.files)
                sources.append(source_i)
                self._names.append({
                    name: connect_name(hparams_i.data_name, name)
                    for name in hparams_i.feature_original_types.keys()
                })
                filters.append(None)

                dataset_hparams = dict_fetch(
                    hparams_i,
                    RecordData.default_hparams()["dataset"])
                self._databases.append(
                    RecordData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            else:
                raise ValueError(f"Unknown data type: {hparams_i.data_type}")

            # check for duplicate names
            for i in range(1, len(name_prefix)):
                if name_prefix[i] in name_prefix[:i]:
                    raise ValueError(f"Duplicate data name: {name_prefix[i]}")

            name_prefix.append(hparams_i["data_name"])

        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

        data_source: DataSource = ZipDataSource(*sources)

        if any(filters):

            def filter_fn(data):
                return all(
                    fn(data) for fn, data in zip(filters, data)
                    if fn is not None)

            data_source = FilterDataSource(data_source, filter_fn=filter_fn)
        super().__init__(data_source, self._hparams, device)
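The constructor expects `hparams["datasets"]` to be a list with one entry per aligned source, each specifying a `data_type` and a unique `data_name`. An illustrative sketch of such a configuration; paths and values are placeholders, and "text"/"int" are assumed here to be the standard texar data type names:

    multi_hparams = {
        "batch_size": 8,
        "datasets": [
            {   # text dataset: yields src_text, src_text_ids, src_length
                "data_type": "text",
                "files": "data/source.txt",
                "vocab_file": "data/vocab.txt",
                "data_name": "src",
            },
            {   # scalar dataset: one integer label per line
                "data_type": "int",
                "files": "data/labels.txt",
                "data_name": "label",
            },
        ],
    }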
Code example #11
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())

        src_hparams = self.hparams.source_dataset
        tgt_hparams = self.hparams.target_dataset

        # create vocabulary
        self._src_bos_token = src_hparams["bos_token"]
        self._src_eos_token = src_hparams["eos_token"]
        self._src_transforms = src_hparams["other_transformations"]
        self._src_vocab = Vocab(src_hparams.vocab_file,
                                bos_token=src_hparams.bos_token,
                                eos_token=src_hparams.eos_token)

        if tgt_hparams["processing_share"]:
            self._tgt_bos_token = src_hparams["bos_token"]
            self._tgt_eos_token = src_hparams["eos_token"]
        else:
            self._tgt_bos_token = tgt_hparams["bos_token"]
            self._tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                          SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                          SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == self._src_vocab.bos_token and \
                    tgt_eos_token == self._src_vocab.eos_token:
                self._tgt_vocab = self._src_vocab
            else:
                self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                        bos_token=tgt_bos_token,
                                        eos_token=tgt_eos_token)
        else:
            self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)

        # create embeddings
        self._src_embedding = MonoTextData.make_embedding(
            src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)

        if self._hparams.target_dataset.embedding_init_share:
            self._tgt_embedding = self._src_embedding
        else:
            tgt_emb_file = tgt_hparams.embedding_init["file"]
            self._tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                self._tgt_embedding = MonoTextData.make_embedding(
                    tgt_hparams.embedding_init,
                    self._tgt_vocab.token_to_id_map_py)

        # create data source
        self._src_delimiter = src_hparams.delimiter
        self._src_max_seq_length = src_hparams.max_seq_length
        self._src_length_filter_mode = _LengthFilterMode(
            src_hparams.length_filter_mode)
        self._src_pad_length = self._src_max_seq_length
        if self._src_pad_length is not None:
            self._src_pad_length += sum(
                int(x is not None and x != '')
                for x in [src_hparams.bos_token, src_hparams.eos_token])

        src_data_source = TextLineDataSource(
            src_hparams.files, compression_type=src_hparams.compression_type)

        self._tgt_transforms = tgt_hparams["other_transformations"]
        self._tgt_delimiter = tgt_hparams.delimiter
        self._tgt_max_seq_length = tgt_hparams.max_seq_length
        self._tgt_length_filter_mode = _LengthFilterMode(
            tgt_hparams.length_filter_mode)
        self._tgt_pad_length = self._tgt_max_seq_length
        if self._tgt_pad_length is not None:
            self._tgt_pad_length += sum(
                int(x is not None and x != '')
                for x in [tgt_hparams.bos_token, tgt_hparams.eos_token])

        tgt_data_source = TextLineDataSource(
            tgt_hparams.files, compression_type=tgt_hparams.compression_type)

        data_source: DataSource[Tuple[List[str], List[str]]]
        data_source = ZipDataSource(  # type: ignore
            src_data_source, tgt_data_source)
        if ((self._src_length_filter_mode is _LengthFilterMode.DISCARD
             and self._src_max_seq_length is not None)
                or (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD
                    and self._tgt_max_seq_length is not None)):
            max_source_length = self._src_max_seq_length or math.inf
            max_tgt_length = self._tgt_max_seq_length or math.inf

            def filter_fn(raw_example):
                return (len(raw_example[0]) <= max_source_length
                        and len(raw_example[1]) <= max_tgt_length)

            data_source = FilterDataSource(data_source, filter_fn)

        super().__init__(data_source, hparams, device=device)
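A configuration sketch matching the fields this constructor reads (`source_dataset` / `target_dataset` plus the `vocab_share`, `processing_share`, and `embedding_init_share` flags); file paths and values are illustrative only:

    paired_hparams = {
        "batch_size": 16,
        "source_dataset": {
            "files": "data/train.src",
            "vocab_file": "data/vocab.src",
            "max_seq_length": 50,
            "length_filter_mode": "discard",  # drop pairs with an over-long side
        },
        "target_dataset": {
            "files": "data/train.tgt",
            "vocab_file": "data/vocab.tgt",
            "vocab_share": False,           # True reuses the source vocabulary
            "processing_share": False,      # True reuses source BOS/EOS settings
            "embedding_init_share": False,  # True shares the source embedding
        },
    }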
Code example #12
    def __init__(self, hparams, device: Optional[torch.device] = None):
        print("Using local texar")
        self._hparams = HParams(hparams, self.default_hparams())
        # Defaultizes hyperparameters of each dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for hparams_i in datasets_hparams:
            data_type = hparams_i.get("data_type", None)
            #print("data_type:", data_type)
            defaultized_ds_hpms = HParams(hparams_i,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        #print("will make_vocab")
        self._vocab = self.make_vocab(self._hparams.datasets)
        #print("will make_embedding")
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        dummy_source = SequenceDataSource[Any]([])
        name_prefix: List[str] = []
        self._names: List[Dict[str, Any]] = []
        sources: List[DataSource] = []
        filters: List[Optional[Callable[[str], bool]]] = []
        self._databases: List[DatasetBase] = []
        for idx, hparams_i in enumerate(self._hparams.datasets):
            data_type = hparams_i.data_type
            source_i: DataSource

            if _is_text_data(data_type):
                #print("will TextLineDataSource")
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type,
                    delimiter=hparams_i.delimiter)
                sources.append(source_i)
                if ((hparams_i.length_filter_mode
                     == _LengthFilterMode.DISCARD.value)
                        and hparams_i.max_seq_length is not None):

                    def _get_filter(max_seq_length):
                        return lambda x: len(x) <= max_seq_length

                    filters.append(_get_filter(hparams_i.max_seq_length))
                else:
                    filters.append(None)

                self._names.append({
                    field: connect_name(hparams_i.data_name, field)
                    for field in ["text", "text_ids", "length"]
                })

                dataset_hparams = dict_fetch(
                    hparams_i,
                    MonoTextData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = None
                self._databases.append(
                    MonoTextData(hparams={"dataset": dataset_hparams},
                                 device=device,
                                 vocab=self._vocab[idx],
                                 embedding=self._embedding[idx],
                                 data_source=dummy_source))
            elif _is_scalar_data(data_type):
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                sources.append(source_i)
                filters.append(None)
                self._names.append({"data": hparams_i.data_name})

                dataset_hparams = dict_fetch(
                    hparams_i,
                    ScalarData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = "data"
                self._databases.append(
                    ScalarData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            elif _is_record_data(data_type):
                source_i = PickleDataSource(file_paths=hparams_i.files)
                sources.append(source_i)
                # TODO: Only check `feature_types` when we finally remove
                #   `feature_original_types`.
                feature_types = (hparams_i.feature_types
                                 or hparams_i.feature_original_types)
                self._names.append({
                    name: connect_name(hparams_i.data_name, name)
                    for name in feature_types.keys()
                })
                filters.append(None)

                dataset_hparams = dict_fetch(
                    hparams_i,
                    RecordData.default_hparams()["dataset"])
                self._databases.append(
                    RecordData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            else:
                raise ValueError(f"Unknown data type: {hparams_i.data_type}")

            # check for duplicate names
            for i in range(1, len(name_prefix)):
                if name_prefix[i] in name_prefix[:i]:
                    raise ValueError(f"Duplicate data name: {name_prefix[i]}")

            name_prefix.append(hparams_i["data_name"])

        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}
        self._processed_cache = []
        self._datafile_id = 0  # for training from multiple files
        self._index_at_beginning_of_this_dataset = 0
        self._datafile_prefix = hparams_i.files
        # Hard-coded; intended to come from `hparams_i.datafile_num`.
        self._datafile_num = 1

        data_source: DataSource = ZipDataSource(*sources)

        if any(filters):

            def filter_fn(data):
                return all(
                    fn(data) for fn, data in zip(filters, data)
                    if fn is not None)

            data_source = FilterDataSource(data_source, filter_fn=filter_fn)
        #print("data init derive done")
        super(MultiAlignedData, self).__init__(data_source, self._hparams,
                                               device)
        # Hard-coded dataset size for the current data files.
        self._dataset_size = 834229