Code example #1
    @classmethod
    def _construct(cls,
                   hparams,
                   device: Optional[torch.device] = None,
                   vocab: Optional[Vocab] = None,
                   embedding: Optional[Embedding] = None):
        mono_text_data = cls.__new__(cls)
        mono_text_data._hparams = HParams(hparams,
                                          mono_text_data.default_hparams())
        if mono_text_data._hparams.dataset.variable_utterance:
            raise NotImplementedError

        dataset = mono_text_data._hparams.dataset
        mono_text_data._other_transforms = dataset.other_transformations

        # Create vocabulary
        if vocab is not None:
            mono_text_data._vocab = vocab
            mono_text_data._bos_token = vocab.bos_token
            mono_text_data._eos_token = vocab.eos_token
        else:
            mono_text_data._bos_token = dataset.bos_token
            mono_text_data._eos_token = dataset.eos_token
            bos = utils.default_str(mono_text_data._bos_token,
                                    SpecialTokens.BOS)
            eos = utils.default_str(mono_text_data._eos_token,
                                    SpecialTokens.EOS)
            mono_text_data._vocab = Vocab(dataset.vocab_file,
                                          bos_token=bos,
                                          eos_token=eos)

        # Create embedding
        if embedding is not None:
            mono_text_data._embedding = embedding
        else:
            mono_text_data._embedding = mono_text_data.make_embedding(
                dataset.embedding_init,
                mono_text_data._vocab.token_to_id_map_py)

        mono_text_data._delimiter = dataset.delimiter
        mono_text_data._max_seq_length = dataset.max_seq_length
        mono_text_data._length_filter_mode = _LengthFilterMode(
            mono_text_data._hparams.dataset.length_filter_mode)
        mono_text_data._pad_length = mono_text_data._max_seq_length
        if mono_text_data._pad_length is not None:
            mono_text_data._pad_length += sum(
                int(x != '') for x in
                [mono_text_data._bos_token, mono_text_data._eos_token])

        data_source: SequenceDataSource[str] = SequenceDataSource([])
        super(MonoTextData, mono_text_data).__init__(source=data_source,
                                                     hparams=hparams,
                                                     device=device)

        return mono_text_data
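The snippet above acts as an alternative constructor: it bypasses __init__ via cls.__new__(cls) so that an already-built vocabulary and embedding can be reused instead of being reconstructed from files. Below is a minimal, hypothetical sketch of using it to share one vocabulary between two splits; the file names and hparams contents are invented, and the vocab/embedding accessors are assumptions, not taken from the example.

    # Hypothetical sketch; file names and hparams contents are invented,
    # and the .vocab / .embedding accessors are assumed to exist.
    from texar.torch.data import MonoTextData

    train_hparams = {
        'dataset': {'files': 'train.txt', 'vocab_file': 'vocab.txt'}}
    val_hparams = {
        'dataset': {'files': 'val.txt', 'vocab_file': 'vocab.txt'}}

    train_data = MonoTextData(train_hparams)
    val_data = MonoTextData._construct(
        val_hparams, vocab=train_data.vocab, embedding=train_data.embedding)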
Code example #2
    def __init__(self, size: int, lazy_strategy: str,
                 cache_strategy: str, unknown_size: bool = False):
        data = list(range(size))
        # An IterDataSource exposes the data only as an iterable of unknown
        # length, while a SequenceDataSource wraps an indexable sequence
        # whose size is known up front.
        source: DataSource[int]
        if unknown_size:
            source = IterDataSource(data)
        else:
            source = SequenceDataSource(data)
        hparams = {
            'lazy_strategy': lazy_strategy,
            'cache_strategy': cache_strategy,
        }
        super().__init__(source, hparams)
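This constructor presumably belongs to a small mock dataset used in the data tests: it wraps a range of integers and forwards the laziness and caching strategies to the base class through hparams. A hedged usage sketch follows, where MockDataset is an invented name for the class that owns this __init__ and the strategy strings are the ones exercised later in example #5.

    # Hypothetical usage; MockDataset is an invented name for the class
    # that owns the __init__ above.
    eager_data = MockDataset(
        size=100, lazy_strategy='none', cache_strategy='loaded')
    lazy_data = MockDataset(
        size=100, lazy_strategy='process', cache_strategy='none',
        unknown_size=True)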
Code example #3
File: record_data.py  Project: haoransh/texar-pytorch
    @classmethod
    def _construct(cls, hparams):
        record_data = cls.__new__(cls)
        record_data._hparams = HParams(hparams, record_data.default_hparams())

        feature_types = record_data._hparams.dataset.feature_original_types
        record_data._features = _convert_feature_hparams(feature_types)

        convert_types = record_data._hparams.dataset.feature_convert_types
        record_data._convert_types = {
            key: get_numpy_dtype(value)
            for key, value in convert_types.items()
        }
        for key, dtype in record_data._convert_types.items():
            record_data._features[key] = record_data._features[key].\
                _replace(dtype=dtype)

        image_options = record_data._hparams.dataset.image_options
        if isinstance(image_options, HParams):
            image_options = [image_options]
        record_data._image_transforms = {}
        for options in image_options:
            key = options.get('image_feature_name')
            if key is None or key not in record_data._features:
                continue
            record_data._image_transforms[key] = _create_image_transform(
                options.get('resize_height'), options.get('resize_width'),
                options.get('resize_method') or 'bilinear')

        record_data._other_transforms = \
            record_data._hparams.dataset.other_transformations

        data_name = record_data._hparams.dataset.data_name
        record_data._items = {
            key: connect_name(data_name, key)
            for key, _ in record_data._features.items()
        }

        data_source = SequenceDataSource([])

        super(RecordData, record_data).__init__(data_source, hparams)
        return record_data
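For orientation, the dataset hparams this constructor reads could look roughly like the following. Only the key names come from the snippet above (and the documented image_options keys it accesses); the file name, feature names, dtypes, and sizes are invented, and the exact value format accepted by feature_original_types is an assumption here.

    # Hypothetical hparams for the constructor above; names, dtypes, and
    # sizes are invented for illustration.
    from texar.torch.data import RecordData

    hparams = {
        'dataset': {
            'files': 'train.pkl',
            'feature_original_types': {
                'image_raw': 'bytes',
                'label': 'int64',
            },
            'feature_convert_types': {'label': 'int32'},
            'image_options': {
                'image_feature_name': 'image_raw',
                'resize_height': 32,
                'resize_width': 32,
                'resize_method': 'bilinear',
            },
        },
        'batch_size': 8,
    }
    data = RecordData._construct(hparams)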
Code example #4
    @classmethod
    def _construct(cls, hparams, device: Optional[torch.device] = None):
        scalar_data = cls.__new__(cls)
        scalar_data._hparams = HParams(hparams, scalar_data.default_hparams())
        scalar_data._other_transforms = \
            scalar_data._hparams.dataset.other_transformations
        # Map the configured data type to the Python cast and numpy dtype
        # used when reading and batching the scalar values.
        data_type = scalar_data._hparams.dataset["data_type"]
        if data_type == "int":
            scalar_data._typecast_func = int
            scalar_data._to_data_type = np.int32
        elif data_type == "float":
            scalar_data._typecast_func = float
            scalar_data._to_data_type = np.float32
        else:
            raise ValueError(
                "Incorrect 'data_type'. Currently 'int' and "
                "'float' are supported. Received {}".format(data_type))
        data_source: SequenceDataSource[str] = SequenceDataSource([])
        super(ScalarData, scalar_data).__init__(
            data_source, hparams, device=device)
        return scalar_data
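ScalarData typically reads a plain text file with one number per line (for example, class labels), and the data_type hparam selects the int or float branch above. A small, hypothetical configuration follows; the file name and batch size are invented.

    # Hypothetical usage; 'labels.txt' is assumed to hold one integer per line.
    from texar.torch.data import ScalarData

    hparams = {
        'dataset': {'files': 'labels.txt', 'data_type': 'int'},
        'batch_size': 16,
    }
    labels = ScalarData(hparams)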
Code example #5
    def _test_modes_with_workers(self,
                                 lazy_mode: str,
                                 cache_mode: str,
                                 num_workers: int,
                                 parallelize_processing: bool = True,
                                 support_random_access: bool = False,
                                 shuffle: bool = False,
                                 **kwargs):
        hparams = {
            'batch_size': self.batch_size,
            'lazy_strategy': lazy_mode,
            'cache_strategy': cache_mode,
            'num_parallel_calls': num_workers,
            'shuffle': shuffle,
            'shuffle_buffer_size': self.size // 5,
            'parallelize_processing': parallelize_processing,
            'allow_smaller_final_batch': False,
            **kwargs,
        }
        numbers_data = [[x] * self.seq_len for x in range(self.size)]
        string_data = [
            ' '.join(map(str, range(self.seq_len))) for _ in range(self.size)
        ]
        if not support_random_access:
            source = ZipDataSource(  # type: ignore
                IterDataSource(numbers_data), SequenceDataSource(string_data))
        else:
            source = ZipDataSource(SequenceDataSource(numbers_data),
                                   SequenceDataSource(string_data))
        data = MockDataBase(source, hparams)  # type: ignore
        iterator = DataIterator(data)

        if data._hparams.allow_smaller_final_batch:
            total_examples = self.size
            total_batches = (self.size + self.batch_size -
                             1) // self.batch_size
        else:
            total_examples = self.size // self.batch_size * self.batch_size
            total_batches = self.size // self.batch_size

        def check_batch(idx, batch):
            if idx == total_batches - 1:
                batch_size = (total_examples - 1) % self.batch_size + 1
            else:
                batch_size = self.batch_size
            self.assertEqual(batch.numbers.shape, (batch_size, self.seq_len))
            if not shuffle:
                numbers = np.asarray(
                    [idx * self.batch_size + x + 1 for x in range(batch_size)])
                self.assertTrue(np.all(batch.numbers == numbers[:,
                                                                np.newaxis]))

        # check laziness
        if parallelize_processing:
            if lazy_mode == 'none':
                self.assertEqual(len(data._processed_cache), self.size)
            else:
                self.assertEqual(len(data._processed_cache), 0)
                if not support_random_access:
                    if lazy_mode == 'process':
                        self.assertEqual(len(data._cached_source._cache),
                                         self.size)
                    else:
                        self.assertEqual(len(data._cached_source._cache), 0)

        # first epoch
        cnt = 0
        for idx, batch in enumerate(iterator):
            check_batch(idx, batch)
            cnt += 1
        self.assertEqual(cnt, total_batches)

        # check cache
        if parallelize_processing:
            if cache_mode == 'none':
                self.assertEqual(len(data._processed_cache), 0)
            elif cache_mode == 'loaded':
                self.assertEqual(len(data._processed_cache), 0)
            else:
                self.assertEqual(len(data._processed_cache), self.size)
            if lazy_mode != 'none' and not support_random_access:
                if cache_mode == 'none':
                    self.assertEqual(len(data._cached_source._cache), 0)
                elif cache_mode == 'loaded':
                    self.assertEqual(len(data._cached_source._cache),
                                     self.size)
                else:
                    self.assertEqual(len(data._cached_source._cache), 0)

        # second epoch
        cnt = 0
        for idx, batch in enumerate(iterator):
            check_batch(idx, batch)
            cnt += 1
        self.assertEqual(cnt, total_batches)

        # check again
        if parallelize_processing:
            if cache_mode == 'none':
                self.assertEqual(len(data._processed_cache), 0)
            elif cache_mode == 'loaded':
                self.assertEqual(len(data._processed_cache), 0)
            else:
                self.assertEqual(len(data._processed_cache), self.size)
            if lazy_mode != 'none' and not support_random_access:
                if cache_mode == 'none':
                    self.assertEqual(len(data._cached_source._cache), 0)
                elif cache_mode == 'loaded':
                    self.assertEqual(len(data._cached_source._cache),
                                     self.size)
                else:
                    self.assertEqual(len(data._cached_source._cache), 0)
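The helper above is parameterized over laziness, caching, worker count, and shuffling so that individual test methods can fan out over combinations. A hedged sketch of how it might be invoked from the same test case follows; the particular combinations are illustrative, and the 'processed' cache strategy is assumed to be the value checked by the final else branches above.

    # Hypothetical test methods driving the helper above; the strategy
    # combinations shown are illustrative only.
    def test_eager_fully_cached(self):
        self._test_modes_with_workers('none', 'processed', num_workers=1)

    def test_lazy_process_parallel(self):
        self._test_modes_with_workers(
            'process', 'loaded', num_workers=2, shuffle=True)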