def test_bucketing(self):
    r"""Tests bucketing.
    """
    hparams = copy.deepcopy(self._hparams)
    hparams.update({
        "bucket_boundaries": [7],
        "bucket_batch_sizes": [6, 4]
    })
    text_data = MonoTextData(hparams)
    iterator = DataIterator(text_data)

    hparams.update({
        "bucket_boundaries": [7],
        "bucket_batch_sizes": [7, 7],
        "allow_smaller_final_batch": False
    })
    text_data_1 = MonoTextData(hparams)
    iterator_1 = DataIterator(text_data_1)

    for data_batch, data_batch_1 in zip(iterator, iterator_1):
        length = data_batch['length'][0]
        if length < 7:
            last_batch_size = hparams['num_epochs'] % 6
            self.assertTrue(
                len(data_batch['text']) == 6 or
                len(data_batch['text']) == last_batch_size)
        else:
            last_batch_size = hparams['num_epochs'] % 4
            self.assertTrue(
                len(data_batch['text']) == 4 or
                len(data_batch['text']) == last_batch_size)

        self.assertEqual(len(data_batch_1['text']), 7)
def _run_and_test(self, hparams):
    # Construct database
    text_data = MonoTextData(hparams)
    self.assertEqual(
        text_data.vocab.size,
        self._vocab_size + len(text_data.vocab.special_tokens))

    iterator = DataIterator(text_data)
    for data_batch in iterator:
        # Run the logic
        self.assertEqual(set(data_batch.keys()),
                         set(text_data.list_items()))

        # Test utterance count
        utt_ind = np.sum(data_batch["text_ids"], 2) != 0
        utt_cnt = np.sum(utt_ind, 1)
        self.assertListEqual(
            data_batch[text_data.utterance_cnt_name].tolist(),
            utt_cnt.tolist())

        if text_data.hparams.dataset.pad_to_max_seq_length:
            max_l = text_data.hparams.dataset.max_seq_length
            max_l += text_data._decoder.added_length
            for x in data_batch['text']:
                for xx in x:
                    self.assertEqual(len(xx), max_l)
            for x in data_batch['text_ids']:
                for xx in x:
                    self.assertEqual(len(xx), max_l)
def _run_and_test(self, hparams, test_batch_size=False,
                  test_transform=False):
    # Construct database
    text_data = MonoTextData(hparams)
    self.assertEqual(
        text_data.vocab.size,
        self._vocab_size + len(text_data.vocab.special_tokens))

    iterator = DataIterator(text_data)
    for data_batch in iterator:
        self.assertEqual(set(data_batch.keys()),
                         set(text_data.list_items()))

        if test_batch_size:
            self.assertEqual(len(data_batch['text']), hparams['batch_size'])

        if test_transform:
            for i in range(len(data_batch['text'])):
                text_ = data_batch['text'][i]
                self.assertTrue(text_ in self.upper_cased_text)

        max_seq_length = text_data.hparams.dataset.max_seq_length
        mode = text_data.hparams.dataset.length_filter_mode

        max_l = max_seq_length
        if max_seq_length is not None:
            if text_data.hparams.dataset.eos_token != '':
                max_l += 1
            if text_data.hparams.dataset.bos_token != '':
                max_l += 1

        if max_seq_length == 6:
            for length in data_batch['length']:
                self.assertLessEqual(length, max_l)
            if mode == "discard":
                for length in data_batch['length']:
                    self.assertEqual(length, 5)
            elif mode == "truncate":
                num_length_6 = 0
                for length in data_batch['length']:
                    num_length_6 += int(length == 6)
                self.assertGreater(num_length_6, 0)
            else:
                raise ValueError("Unknown mode: %s" % mode)

        if text_data.hparams.dataset.pad_to_max_seq_length:
            for x in data_batch['text']:
                self.assertEqual(len(x), max_l)
            for x in data_batch['text_ids']:
                self.assertEqual(len(x), max_l)
def test_train_test_data_iterator(self):
    r"""Tests :class:`texar.torch.data.TrainTestDataIterator`.
    """
    train = MonoTextData(self._train_hparams)
    test = MonoTextData(self._test_hparams)
    train_batch_size = self._train_hparams["batch_size"]
    test_batch_size = self._test_hparams["batch_size"]

    data_iterator = TrainTestDataIterator(train=train, test=test)
    data_iterator.switch_to_train_data()
    iterator = data_iterator.get_iterator()
    for idx, val in enumerate(iterator):
        self.assertEqual(len(val), train_batch_size)
        number = idx * train_batch_size + 1
        self.assertEqual(val.text[0], [str(number)])
        # numbers: 1 - 2000, first 4 vocab entries are special tokens
        self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

    data_iterator.switch_to_test_data()
    iterator = data_iterator.get_iterator()
    for idx, val in enumerate(iterator):
        self.assertEqual(len(val), test_batch_size)
        number = idx * test_batch_size + 1001
        self.assertEqual(val.text[0], [str(number)])
        self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

    # test `get_*_iterator` interface
    for idx, val in enumerate(data_iterator.get_test_iterator()):
        self.assertEqual(len(val), test_batch_size)
        number = idx * test_batch_size + 1001
        self.assertEqual(val.text[0], [str(number)])
        self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

    # test exception for invalid dataset name
    with self.assertRaises(ValueError) as context:
        data_iterator.switch_to_val_data()
def test_iterator_multi_datasets(self):
    r"""Tests iterating over multiple datasets.
    """
    train = MonoTextData(self._train_hparams)
    test = MonoTextData(self._test_hparams)
    train_batch_size = self._train_hparams["batch_size"]
    test_batch_size = self._test_hparams["batch_size"]

    data_iterator = DataIterator({"train": train, "test": test})
    data_iterator.switch_to_dataset(dataset_name="train")
    iterator = data_iterator.get_iterator()
    for idx, val in enumerate(iterator):
        self.assertEqual(len(val), train_batch_size)
        number = idx * train_batch_size + 1
        self.assertEqual(val.text[0], [str(number)])
        # numbers: 1 - 2000, first 4 vocab entries are special tokens
        self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

    data_iterator.switch_to_dataset(dataset_name="test")
    iterator = data_iterator.get_iterator()
    for idx, val in enumerate(iterator):
        self.assertEqual(len(val), test_batch_size)
        number = idx * test_batch_size + 1001
        self.assertEqual(val.text[0], [str(number)])
        self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

    # test `get_iterator` interface
    for idx, val in enumerate(data_iterator.get_iterator('train')):
        self.assertEqual(len(val), train_batch_size)
        number = idx * train_batch_size + 1
        self.assertEqual(val.text[0], [str(number)])
        self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

    # test exception for invalid dataset name
    with self.assertRaises(ValueError) as context:
        data_iterator.switch_to_dataset('val')
    self.assertTrue('not found' in str(context.exception))
def test_iterator_single_dataset(self):
    r"""Tests iterating over a single dataset.
    """
    data = MonoTextData(self._test_hparams)
    data_iterator = DataIterator(data)
    data_iterator.switch_to_dataset(dataset_name="data")
    iterator = data_iterator.get_iterator()

    i = 1001
    for idx, batch in enumerate(iterator):
        self.assertEqual(batch.batch_size, self._test_hparams['batch_size'])
        np.testing.assert_array_equal(batch['length'], [1, 1])
        for example in batch['text']:
            self.assertEqual(example[0], str(i))
            i += 1
    self.assertEqual(i, 2001)
def test_auto_storage_moving(self):
    r"""Tests that batches are automatically moved to the CUDA device.
    """
    cuda_tensors = set()

    def move_tensor(tensor, device, non_blocking=False):
        if isinstance(device, torch.device) and device.type == "cuda":
            self.assertTrue(non_blocking)
            cuda_tensors.add(id(tensor))
        return tensor

    device = torch.device("cuda:0")
    with patch.object(torch.Tensor, "to", move_tensor):
        train = MonoTextData(self._train_hparams, device=device)
        iterator = DataIterator(train)
        for batch in iterator:
            self.assertTrue(id(batch.text_ids) in cuda_tensors)
            self.assertTrue(id(batch.length) in cuda_tensors)
def test_list_items(self):
    r"""Tests the item names of the output data.
    """
    text_data = MonoTextData(self._hparams)
    self.assertSetEqual(set(text_data.list_items()),
                        {"text", "text_ids", "length"})

    hparams = copy.deepcopy(self._hparams)
    hparams["dataset"]["data_name"] = "data"
    text_data = MonoTextData(hparams)
    self.assertSetEqual(set(text_data.list_items()),
                        {"data_text", "data_text_ids", "data_length"})
def make_embedding(src_emb_hparams, src_token_to_id_map,
                   tgt_emb_hparams=None, tgt_token_to_id_map=None,
                   emb_init_share=False):
    r"""Optionally loads source and target embeddings from files (if
    provided), and returns the respective
    :class:`texar.torch.data.Embedding` instances.
    """
    src_embedding = MonoTextData.make_embedding(src_emb_hparams,
                                                src_token_to_id_map)

    if emb_init_share:
        tgt_embedding = src_embedding
    else:
        tgt_emb_file = tgt_emb_hparams["file"]
        tgt_embedding = None
        if tgt_emb_file is not None and tgt_emb_file != "":
            tgt_embedding = Embedding(tgt_token_to_id_map, tgt_emb_hparams)

    return src_embedding, tgt_embedding
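# A minimal usage sketch of the helper above (added for illustration, not part
# of the original module). The file path and vocab maps are placeholders:
# loading the source embedding would require the referenced file to exist, and
# the empty target "file" means no separate target embedding is created.
_example_src_emb_hparams = {"file": "data/src_embedding.txt"}
_example_tgt_emb_hparams = {"file": ""}
_example_src_map = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2, "<UNK>": 3, "hello": 4}
_example_tgt_map = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2, "<UNK>": 3, "bonjour": 4}
# src_emb, tgt_emb = make_embedding(
#     _example_src_emb_hparams, _example_src_map,
#     tgt_emb_hparams=_example_tgt_emb_hparams,
#     tgt_token_to_id_map=_example_tgt_map,
#     emb_init_share=False)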
def __init__(self, hparams, device: Optional[torch.device] = None):
    self._hparams = HParams(hparams, self.default_hparams())

    # Defaultizes hyperparameters of each dataset
    datasets_hparams = self._hparams.datasets
    defaultized_datasets_hparams = []
    for hparams_i in datasets_hparams:
        data_type = hparams_i.get("data_type", None)
        defaultized_ds_hpms = HParams(hparams_i,
                                      _default_dataset_hparams(data_type))
        defaultized_datasets_hparams.append(defaultized_ds_hpms)
    self._hparams.datasets = defaultized_datasets_hparams

    self._vocab = self.make_vocab(self._hparams.datasets)
    self._embedding = self.make_embedding(self._hparams.datasets,
                                          self._vocab)

    dummy_source = SequenceDataSource[Any]([])
    name_prefix: List[str] = []
    self._names: List[Dict[str, Any]] = []
    sources: List[DataSource] = []
    filters: List[Optional[Callable[[str], bool]]] = []
    self._databases: List[DataBase] = []
    for idx, hparams_i in enumerate(self._hparams.datasets):
        data_type = _DataType(hparams_i.data_type)
        source_i: DataSource

        if _is_text_data(data_type):
            source_i = TextLineDataSource(
                hparams_i.files,
                compression_type=hparams_i.compression_type,
                delimiter=hparams_i.delimiter)
            sources.append(source_i)
            if ((hparams_i.length_filter_mode ==
                 _LengthFilterMode.DISCARD.value) and
                    hparams_i.max_seq_length is not None):

                def _get_filter(max_seq_length):
                    return lambda x: len(x) <= max_seq_length

                filters.append(_get_filter(hparams_i.max_seq_length))
            else:
                filters.append(None)

            self._names.append({
                field: connect_name(hparams_i.data_name, field)
                for field in ["text", "text_ids", "length"]
            })

            dataset_hparams = dict_fetch(
                hparams_i, MonoTextData.default_hparams()["dataset"])
            dataset_hparams["data_name"] = None
            self._databases.append(
                MonoTextData(hparams={"dataset": dataset_hparams},
                             device=device, vocab=self._vocab[idx],
                             embedding=self._embedding[idx],
                             data_source=dummy_source))
        elif _is_scalar_data(data_type):
            source_i = TextLineDataSource(
                hparams_i.files,
                compression_type=hparams_i.compression_type)
            sources.append(source_i)
            filters.append(None)

            self._names.append({"data": hparams_i.data_name})

            dataset_hparams = dict_fetch(
                hparams_i, ScalarData.default_hparams()["dataset"])
            dataset_hparams["data_name"] = "data"
            self._databases.append(
                ScalarData(hparams={"dataset": dataset_hparams},
                           device=device, data_source=dummy_source))
        elif _is_record_data(data_type):
            source_i = PickleDataSource(file_paths=hparams_i.files)
            sources.append(source_i)
            self._names.append({
                name: connect_name(hparams_i.data_name, name)
                for name in hparams_i.feature_original_types.keys()
            })
            filters.append(None)

            dataset_hparams = dict_fetch(
                hparams_i, RecordData.default_hparams()["dataset"])
            self._databases.append(
                RecordData(hparams={"dataset": dataset_hparams},
                           device=device, data_source=dummy_source))
        else:
            raise ValueError(f"Unknown data type: {hparams_i.data_type}")

        # check for duplicate names
        for i in range(1, len(name_prefix)):
            if name_prefix[i] in name_prefix[:i]:
                raise ValueError(f"Duplicate data name: {name_prefix[i]}")

        name_prefix.append(hparams_i["data_name"])

    self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

    data_source: DataSource = ZipDataSource(*sources)

    if any(filters):
        def filter_fn(data):
            return all(fn(data) for fn, data in zip(filters, data)
                       if fn is not None)

        data_source = FilterDataSource(data_source, filter_fn=filter_fn)

    super().__init__(data_source, self._hparams, device)
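# A hedged configuration sketch (not from the original source): roughly the
# shape of `hparams` the constructor above parses -- a "datasets" list mixing
# a text dataset, a scalar dataset, and a record dataset that get zipped
# together. File paths, data names, and the feature spec are placeholders;
# the exact `data_type` strings and feature-type format should be checked
# against `default_hparams()`.
_example_multi_aligned_hparams = {
    "batch_size": 64,
    "datasets": [
        {
            # text dataset -> items "x_text", "x_text_ids", "x_length"
            "files": "data/x.txt",
            "vocab_file": "data/vocab.txt",
            "data_name": "x",
            "data_type": "text",
            "max_seq_length": 50,
            "length_filter_mode": "discard",
        },
        {
            # scalar dataset -> item "label"
            "files": "data/labels.txt",
            "data_name": "label",
            "data_type": "int",
        },
        {
            # record dataset (pickled features) -> items prefixed with "img"
            "files": "data/features.pkl",
            "data_name": "img",
            "data_type": "record",
            "feature_original_types": {"feature": ["float32",
                                                   "FixedLenFeature"]},
        },
    ],
}
# data = MultiAlignedData(_example_multi_aligned_hparams)
# for batch in DataIterator(data):
#     ...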
def __init__(self, hparams, device: Optional[torch.device] = None):
    self._hparams = HParams(hparams, self.default_hparams())
    src_hparams = self.hparams.source_dataset
    tgt_hparams = self.hparams.target_dataset

    # create vocabulary
    self._src_bos_token = src_hparams["bos_token"]
    self._src_eos_token = src_hparams["eos_token"]
    self._src_transforms = src_hparams["other_transformations"]
    self._src_vocab = Vocab(src_hparams.vocab_file,
                            bos_token=src_hparams.bos_token,
                            eos_token=src_hparams.eos_token)

    if tgt_hparams["processing_share"]:
        self._tgt_bos_token = src_hparams["bos_token"]
        self._tgt_eos_token = src_hparams["eos_token"]
    else:
        self._tgt_bos_token = tgt_hparams["bos_token"]
        self._tgt_eos_token = tgt_hparams["eos_token"]
    tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                      SpecialTokens.BOS)
    tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                      SpecialTokens.EOS)
    if tgt_hparams["vocab_share"]:
        if tgt_bos_token == self._src_vocab.bos_token and \
                tgt_eos_token == self._src_vocab.eos_token:
            self._tgt_vocab = self._src_vocab
        else:
            self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)
    else:
        self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                bos_token=tgt_bos_token,
                                eos_token=tgt_eos_token)

    # create embeddings
    self._src_embedding = MonoTextData.make_embedding(
        src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)

    if self._hparams.target_dataset.embedding_init_share:
        self._tgt_embedding = self._src_embedding
    else:
        tgt_emb_file = tgt_hparams.embedding_init["file"]
        self._tgt_embedding = None
        if tgt_emb_file is not None and tgt_emb_file != "":
            self._tgt_embedding = MonoTextData.make_embedding(
                tgt_hparams.embedding_init,
                self._tgt_vocab.token_to_id_map_py)

    # create data source
    self._src_delimiter = src_hparams.delimiter
    self._src_max_seq_length = src_hparams.max_seq_length
    self._src_length_filter_mode = _LengthFilterMode(
        src_hparams.length_filter_mode)
    self._src_pad_length = self._src_max_seq_length
    if self._src_pad_length is not None:
        self._src_pad_length += sum(
            int(x is not None and x != '')
            for x in [src_hparams.bos_token, src_hparams.eos_token])

    src_data_source = TextLineDataSource(
        src_hparams.files, compression_type=src_hparams.compression_type)

    self._tgt_transforms = tgt_hparams["other_transformations"]
    self._tgt_delimiter = tgt_hparams.delimiter
    self._tgt_max_seq_length = tgt_hparams.max_seq_length
    self._tgt_length_filter_mode = _LengthFilterMode(
        tgt_hparams.length_filter_mode)
    self._tgt_pad_length = self._tgt_max_seq_length
    if self._tgt_pad_length is not None:
        self._tgt_pad_length += sum(
            int(x is not None and x != '')
            for x in [tgt_hparams.bos_token, tgt_hparams.eos_token])

    tgt_data_source = TextLineDataSource(
        tgt_hparams.files, compression_type=tgt_hparams.compression_type)

    data_source: DataSource[Tuple[List[str], List[str]]]
    data_source = ZipDataSource(  # type: ignore
        src_data_source, tgt_data_source)

    if ((self._src_length_filter_mode is _LengthFilterMode.DISCARD and
         self._src_max_seq_length is not None) or
            (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD and
             self._tgt_max_seq_length is not None)):
        max_source_length = self._src_max_seq_length or math.inf
        max_tgt_length = self._tgt_max_seq_length or math.inf

        def filter_fn(raw_example):
            return (len(raw_example[0]) <= max_source_length and
                    len(raw_example[1]) <= max_tgt_length)

        data_source = FilterDataSource(data_source, filter_fn)

    super().__init__(data_source, hparams, device=device)
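# A hedged configuration sketch (not from the original source): a minimal
# `hparams` for the paired source/target constructor above (presumably
# texar's PairedTextData), with a shared vocabulary, shared BOS/EOS
# processing, and a shared embedding initialization. Paths are placeholders.
_example_paired_text_hparams = {
    "batch_size": 32,
    "source_dataset": {
        "files": "data/train.src",
        "vocab_file": "data/vocab.txt",
        "max_seq_length": 50,
        "length_filter_mode": "discard",
    },
    "target_dataset": {
        "files": "data/train.tgt",
        "vocab_share": True,        # reuse the source Vocab
        "processing_share": True,   # reuse source bos/eos tokens
        "embedding_init_share": True,
    },
}
# data = PairedTextData(_example_paired_text_hparams)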
def __init__(self, hparams, device: Optional[torch.device] = None):
    print("Using local texar")
    self._hparams = HParams(hparams, self.default_hparams())

    # Defaultizes hyperparameters of each dataset
    datasets_hparams = self._hparams.datasets
    defaultized_datasets_hparams = []
    for hparams_i in datasets_hparams:
        data_type = hparams_i.get("data_type", None)
        defaultized_ds_hpms = HParams(hparams_i,
                                      _default_dataset_hparams(data_type))
        defaultized_datasets_hparams.append(defaultized_ds_hpms)
    self._hparams.datasets = defaultized_datasets_hparams

    self._vocab = self.make_vocab(self._hparams.datasets)
    self._embedding = self.make_embedding(self._hparams.datasets,
                                          self._vocab)

    dummy_source = SequenceDataSource[Any]([])
    name_prefix: List[str] = []
    self._names: List[Dict[str, Any]] = []
    sources: List[DataSource] = []
    filters: List[Optional[Callable[[str], bool]]] = []
    self._databases: List[DatasetBase] = []
    for idx, hparams_i in enumerate(self._hparams.datasets):
        data_type = hparams_i.data_type
        source_i: DataSource

        if _is_text_data(data_type):
            source_i = TextLineDataSource(
                hparams_i.files,
                compression_type=hparams_i.compression_type,
                delimiter=hparams_i.delimiter)
            sources.append(source_i)
            if ((hparams_i.length_filter_mode ==
                 _LengthFilterMode.DISCARD.value) and
                    hparams_i.max_seq_length is not None):

                def _get_filter(max_seq_length):
                    return lambda x: len(x) <= max_seq_length

                filters.append(_get_filter(hparams_i.max_seq_length))
            else:
                filters.append(None)

            self._names.append({
                field: connect_name(hparams_i.data_name, field)
                for field in ["text", "text_ids", "length"]
            })

            dataset_hparams = dict_fetch(
                hparams_i, MonoTextData.default_hparams()["dataset"])
            dataset_hparams["data_name"] = None
            self._databases.append(
                MonoTextData(hparams={"dataset": dataset_hparams},
                             device=device, vocab=self._vocab[idx],
                             embedding=self._embedding[idx],
                             data_source=dummy_source))
        elif _is_scalar_data(data_type):
            source_i = TextLineDataSource(
                hparams_i.files,
                compression_type=hparams_i.compression_type)
            sources.append(source_i)
            filters.append(None)

            self._names.append({"data": hparams_i.data_name})

            dataset_hparams = dict_fetch(
                hparams_i, ScalarData.default_hparams()["dataset"])
            dataset_hparams["data_name"] = "data"
            self._databases.append(
                ScalarData(hparams={"dataset": dataset_hparams},
                           device=device, data_source=dummy_source))
        elif _is_record_data(data_type):
            source_i = PickleDataSource(file_paths=hparams_i.files)
            sources.append(source_i)
            # TODO: Only check `feature_types` when we finally remove
            #   `feature_original_types`.
            feature_types = (hparams_i.feature_types or
                             hparams_i.feature_original_types)
            self._names.append({
                name: connect_name(hparams_i.data_name, name)
                for name in feature_types.keys()
            })
            filters.append(None)

            dataset_hparams = dict_fetch(
                hparams_i, RecordData.default_hparams()["dataset"])
            self._databases.append(
                RecordData(hparams={"dataset": dataset_hparams},
                           device=device, data_source=dummy_source))
        else:
            raise ValueError(f"Unknown data type: {hparams_i.data_type}")

        # check for duplicate names
        for i in range(1, len(name_prefix)):
            if name_prefix[i] in name_prefix[:i]:
                raise ValueError(f"Duplicate data name: {name_prefix[i]}")

        name_prefix.append(hparams_i["data_name"])

    self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

    self._processed_cache = []
    self._datafile_id = 0  # for training from multiple files
    self._index_at_beginning_of_this_dataset = 0
    self._datafile_prefix = hparams_i.files
    # Number of data files is hard-coded for the local setup rather than
    # read from `hparams_i.datafile_num`.
    self._datafile_num = 1

    data_source: DataSource = ZipDataSource(*sources)

    if any(filters):
        def filter_fn(data):
            return all(fn(data) for fn, data in zip(filters, data)
                       if fn is not None)

        data_source = FilterDataSource(data_source, filter_fn=filter_fn)

    super(MultiAlignedData, self).__init__(data_source, self._hparams,
                                           device)
    # Dataset size is hard-coded for the local dataset instead of being
    # computed from the data source.
    self._dataset_size = 834229