def _construct(cls, hparams, device: Optional[torch.device] = None,
               vocab: Optional[Vocab] = None,
               embedding: Optional[Vocab] = None):
    mono_text_data = cls.__new__(cls)
    mono_text_data._hparams = HParams(hparams,
                                      mono_text_data.default_hparams())
    if mono_text_data._hparams.dataset.variable_utterance:
        raise NotImplementedError

    dataset = mono_text_data._hparams.dataset
    mono_text_data._other_transforms = dataset.other_transformations

    # Create vocabulary
    if vocab is not None:
        mono_text_data._vocab = vocab
        mono_text_data._bos_token = vocab.bos_token
        mono_text_data._eos_token = vocab.eos_token
    else:
        mono_text_data._bos_token = dataset.bos_token
        mono_text_data._eos_token = dataset.eos_token
        bos = utils.default_str(mono_text_data._bos_token,
                                SpecialTokens.BOS)
        eos = utils.default_str(mono_text_data._eos_token,
                                SpecialTokens.EOS)
        mono_text_data._vocab = Vocab(dataset.vocab_file,
                                      bos_token=bos, eos_token=eos)

    # Create embedding
    if embedding is not None:
        mono_text_data._embedding = embedding
    else:
        mono_text_data._embedding = mono_text_data.make_embedding(
            dataset.embedding_init,
            mono_text_data._vocab.token_to_id_map_py)

    mono_text_data._delimiter = dataset.delimiter
    mono_text_data._max_seq_length = dataset.max_seq_length
    mono_text_data._length_filter_mode = _LengthFilterMode(
        mono_text_data._hparams.dataset.length_filter_mode)
    mono_text_data._pad_length = mono_text_data._max_seq_length
    if mono_text_data._pad_length is not None:
        mono_text_data._pad_length += sum(
            int(x != '') for x in
            [mono_text_data._bos_token, mono_text_data._eos_token])

    data_source: SequenceDataSource[str] = SequenceDataSource([])
    super(MonoTextData, mono_text_data).__init__(
        source=data_source, hparams=hparams, device=device)

    return mono_text_data
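
# Usage sketch (illustrative; not part of the module above): the factory
# exists so an already-built `Vocab` (and optionally an embedding) can be
# shared with a new dataset instance instead of being re-read from disk.
# It assumes `_construct` is a classmethod (it takes `cls`) and that
# `MonoTextData` exposes a `vocab` property as the code above suggests;
# file contents and hyperparameter values are made up for illustration.
import os
import tempfile

def _vocab_sharing_sketch():
    tmp_dir = tempfile.mkdtemp()
    text_file = os.path.join(tmp_dir, 'sentences.txt')
    vocab_file = os.path.join(tmp_dir, 'vocab.txt')
    with open(text_file, 'w') as f:
        f.write('hello world\nfoo bar\n')
    with open(vocab_file, 'w') as f:
        f.write('hello\nworld\nfoo\nbar\n')

    hparams = {
        'batch_size': 2,
        'dataset': {'files': text_file, 'vocab_file': vocab_file},
    }
    train = MonoTextData(hparams)  # regular construction
    # Build a second dataset that reuses the training vocabulary.
    valid = MonoTextData._construct(hparams, vocab=train.vocab)
    assert valid.vocab is train.vocab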
def __init__(self, size: int, lazy_strategy: str, cache_strategy: str,
             unknown_size: bool = False):
    data = list(range(size))
    source: DataSource[int]
    if unknown_size:
        source = IterDataSource(data)
    else:
        source = SequenceDataSource(data)
    hparams = {
        'lazy_strategy': lazy_strategy,
        'cache_strategy': cache_strategy,
    }
    super().__init__(source, hparams)
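
# Illustrative sketch (not part of the test suite): the constructor above
# only decides which data source wraps `range(size)`. A `SequenceDataSource`
# supports random access (indexing and `len()`), whereas an `IterDataSource`
# can only be iterated, which is what "unknown size" means to the
# lazy-loading logic under test. The assertions below assume the two source
# classes behave as sketched (sequence-like vs. iterable-only).
def _source_choice_sketch():
    seq = SequenceDataSource(list(range(5)))
    assert len(seq) == 5 and seq[2] == 2        # random access is available
    it = IterDataSource(range(5))
    assert list(it) == [0, 1, 2, 3, 4]          # iteration only, size unknown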
def _construct(cls, hparams):
    record_data = cls.__new__(cls)
    record_data._hparams = HParams(hparams, record_data.default_hparams())

    feature_types = record_data._hparams.dataset.feature_original_types
    record_data._features = _convert_feature_hparams(feature_types)

    convert_types = record_data._hparams.dataset.feature_convert_types
    record_data._convert_types = {key: get_numpy_dtype(value)
                                  for key, value in convert_types.items()}
    for key, dtype in record_data._convert_types.items():
        record_data._features[key] = \
            record_data._features[key]._replace(dtype=dtype)

    image_options = record_data._hparams.dataset.image_options
    if isinstance(image_options, HParams):
        image_options = [image_options]
    record_data._image_transforms = {}
    for options in image_options:
        key = options.get('image_feature_name')
        if key is None or key not in record_data._features:
            continue
        record_data._image_transforms[key] = _create_image_transform(
            options.get('resize_height'), options.get('resize_width'),
            options.get('resize_method') or 'bilinear')

    record_data._other_transforms = \
        record_data._hparams.dataset.other_transformations

    data_name = record_data._hparams.dataset.data_name
    record_data._items = {
        key: connect_name(data_name, key)
        for key, _ in record_data._features.items()
    }

    data_source = SequenceDataSource([])
    super(RecordData, record_data).__init__(data_source, hparams)

    return record_data
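
# Configuration sketch (illustrative values only): the factory above consumes
# a `dataset` spec in which `feature_original_types` declares the stored
# features, `feature_convert_types` overrides their dtypes after loading, and
# `image_options` attaches a resize transform keyed by `image_feature_name`.
# The feature names, dtypes, and sizes below are made up; consult
# `RecordData.default_hparams()` for the authoritative format.
def _record_data_hparams_sketch():
    hparams = {
        'batch_size': 8,
        'dataset': {
            'feature_original_types': {
                # feature name: [dtype, collation strategy]
                'label': ['int64', 'FixedLenFeature'],
                'image_raw': ['bytes', 'FixedLenFeature'],
            },
            'feature_convert_types': {
                'label': 'float32',          # cast after loading
            },
            'image_options': {
                'image_feature_name': 'image_raw',
                'resize_height': 64,
                'resize_width': 64,
            },
            'data_name': 'img',
        },
    }
    # Something like `RecordData._construct(hparams)` (assuming the factory is
    # a classmethod) would register these features over an empty data source.
    return hparams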
def _construct(cls, hparams, device: Optional[torch.device] = None):
    scalar_data = cls.__new__(cls)
    scalar_data._hparams = HParams(hparams, scalar_data.default_hparams())
    scalar_data._other_transforms = \
        scalar_data._hparams.dataset.other_transformations
    data_type = scalar_data._hparams.dataset["data_type"]
    if data_type == "int":
        scalar_data._typecast_func = int
        scalar_data._to_data_type = np.int32
    elif data_type == "float":
        scalar_data._typecast_func = float
        scalar_data._to_data_type = np.float32
    else:
        raise ValueError(
            "Incorrect 'data_type'. Currently 'int' and "
            "'float' are supported. Received {}".format(data_type))

    data_source: SequenceDataSource[str] = SequenceDataSource([])
    super(ScalarData, scalar_data).__init__(data_source, hparams,
                                            device=device)

    return scalar_data
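
# Usage sketch (illustrative; not part of the module above): the factory only
# resolves `data_type` into a Python cast (`int` / `float`) and a NumPy dtype
# (`np.int32` / `np.float32`) and rejects anything else. Below is a minimal
# end-to-end example using the regular constructor and a throwaway file of one
# scalar per line; the `DataIterator` import path and the `data_name`-keyed
# batch attribute are assumptions.
import os
import tempfile

def _scalar_data_sketch():
    path = os.path.join(tempfile.mkdtemp(), 'labels.txt')
    with open(path, 'w') as f:
        f.write('0\n1\n1\n0\n')
    hparams = {
        'batch_size': 2,
        'dataset': {'files': path, 'data_type': 'int', 'data_name': 'label'},
    }
    data = ScalarData(hparams)
    from texar.torch.data import DataIterator  # assumed import path
    batch = next(iter(DataIterator(data)))
    # Expected: `batch.label` is an int32 tensor of shape (batch_size,).
    return batch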
def _test_modes_with_workers(self, lazy_mode: str, cache_mode: str,
                             num_workers: int,
                             parallelize_processing: bool = True,
                             support_random_access: bool = False,
                             shuffle: bool = False,
                             **kwargs):
    hparams = {
        'batch_size': self.batch_size,
        'lazy_strategy': lazy_mode,
        'cache_strategy': cache_mode,
        'num_parallel_calls': num_workers,
        'shuffle': shuffle,
        'shuffle_buffer_size': self.size // 5,
        'parallelize_processing': parallelize_processing,
        'allow_smaller_final_batch': False,
        **kwargs,
    }
    numbers_data = [[x] * self.seq_len for x in range(self.size)]
    string_data = [' '.join(map(str, range(self.seq_len)))
                   for _ in range(self.size)]
    if not support_random_access:
        source = ZipDataSource(  # type: ignore
            IterDataSource(numbers_data),
            SequenceDataSource(string_data))
    else:
        source = ZipDataSource(SequenceDataSource(numbers_data),
                               SequenceDataSource(string_data))
    data = MockDataBase(source, hparams)  # type: ignore
    iterator = DataIterator(data)

    if data._hparams.allow_smaller_final_batch:
        total_examples = self.size
        total_batches = (self.size + self.batch_size - 1) // self.batch_size
    else:
        total_examples = self.size // self.batch_size * self.batch_size
        total_batches = self.size // self.batch_size

    def check_batch(idx, batch):
        if idx == total_batches - 1:
            batch_size = (total_examples - 1) % self.batch_size + 1
        else:
            batch_size = self.batch_size
        self.assertEqual(batch.numbers.shape,
                         (batch_size, self.seq_len))
        if not shuffle:
            numbers = np.asarray([idx * self.batch_size + x + 1
                                  for x in range(batch_size)])
            self.assertTrue(np.all(batch.numbers == numbers[:, np.newaxis]))

    # check laziness
    if parallelize_processing:
        if lazy_mode == 'none':
            self.assertEqual(len(data._processed_cache), self.size)
        else:
            self.assertEqual(len(data._processed_cache), 0)
            if not support_random_access:
                if lazy_mode == 'process':
                    self.assertEqual(len(data._cached_source._cache),
                                     self.size)
                else:
                    self.assertEqual(len(data._cached_source._cache), 0)

    # first epoch
    cnt = 0
    for idx, batch in enumerate(iterator):
        check_batch(idx, batch)
        cnt += 1
    self.assertEqual(cnt, total_batches)

    # check cache
    if parallelize_processing:
        if cache_mode == 'none':
            self.assertEqual(len(data._processed_cache), 0)
        elif cache_mode == 'loaded':
            self.assertEqual(len(data._processed_cache), 0)
        else:
            self.assertEqual(len(data._processed_cache), self.size)
        if lazy_mode != 'none' and not support_random_access:
            if cache_mode == 'none':
                self.assertEqual(len(data._cached_source._cache), 0)
            elif cache_mode == 'loaded':
                self.assertEqual(len(data._cached_source._cache), self.size)
            else:
                self.assertEqual(len(data._cached_source._cache), 0)

    # second epoch
    cnt = 0
    for idx, batch in enumerate(iterator):
        check_batch(idx, batch)
        cnt += 1
    self.assertEqual(cnt, total_batches)

    # check again
    if parallelize_processing:
        if cache_mode == 'none':
            self.assertEqual(len(data._processed_cache), 0)
        elif cache_mode == 'loaded':
            self.assertEqual(len(data._processed_cache), 0)
        else:
            self.assertEqual(len(data._processed_cache), self.size)
        if lazy_mode != 'none' and not support_random_access:
            if cache_mode == 'none':
                self.assertEqual(len(data._cached_source._cache), 0)
            elif cache_mode == 'loaded':
                self.assertEqual(len(data._cached_source._cache), self.size)
            else:
                self.assertEqual(len(data._cached_source._cache), 0)
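
# Usage sketch: concrete tests are expected to drive the helper above across
# the lazy strategies ('none', 'process', 'all'), cache strategies ('none',
# 'loaded', 'processed'), and worker counts it parameterizes. The test names
# below are illustrative, not the actual suite, and assume the methods live in
# the same TestCase as `_test_modes_with_workers`.
def test_eager_single_worker(self):
    self._test_modes_with_workers('none', 'processed', 1)

def test_lazy_loading_multi_worker(self):
    self._test_modes_with_workers('all', 'loaded', 2)

def test_no_parallel_processing(self):
    self._test_modes_with_workers('process', 'processed', 2,
                                  parallelize_processing=False)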