Example #1
    def _make_processor(src_hparams, tgt_hparams, data_spec, name_prefix):
        # Create source data decoder
        data_spec_i = data_spec.get_ith_data_spec(0)
        src_decoder, src_trans, data_spec_i = MonoTextData._make_processor(
            src_hparams, data_spec_i, chained=False)
        data_spec.set_ith_data_spec(0, data_spec_i, 2)

        # Create target data decoder
        tgt_proc_hparams = tgt_hparams
        if tgt_hparams["processing_share"]:
            tgt_proc_hparams = copy.copy(src_hparams)
            try:
                tgt_proc_hparams["variable_utterance"] = \
                        tgt_hparams["variable_utterance"]
            except TypeError:
                tgt_proc_hparams.variable_utterance = \
                        tgt_hparams["variable_utterance"]
        data_spec_i = data_spec.get_ith_data_spec(1)
        tgt_decoder, tgt_trans, data_spec_i = MonoTextData._make_processor(
            tgt_proc_hparams, data_spec_i, chained=False)
        data_spec.set_ith_data_spec(1, data_spec_i, 2)

        tran_fn = dsutils.make_combined_transformation(
            [[src_decoder] + src_trans, [tgt_decoder] + tgt_trans],
            name_prefix=name_prefix)

        data_spec.add_spec(name_prefix=name_prefix)

        return tran_fn, data_spec
Example #2
    def _make_length_filter(src_hparams, tgt_hparams, src_length_name,
                            tgt_length_name, src_decoder, tgt_decoder):
        src_filter_fn = MonoTextData._make_length_filter(
            src_hparams, src_length_name, src_decoder)
        tgt_filter_fn = MonoTextData._make_length_filter(
            tgt_hparams, tgt_length_name, tgt_decoder)
        combined_filter_fn = dsutils._make_combined_filter_fn(
            [src_filter_fn, tgt_filter_fn])
        return combined_filter_fn
Example #3
    def _make_data(self):
        dataset_hparams = self._hparams.dataset

        # Create and shuffle dataset
        dataset = MonoTextData._make_mono_text_dataset(dataset_hparams)
        dataset, dataset_size = self._shuffle_dataset(
            dataset, self._hparams, self._hparams.dataset.files)
        self._dataset_size = dataset_size

        # Processing
        # pylint: disable=protected-access
        data_spec = dsutils._DataSpec(dataset=dataset,
                                      dataset_size=self._dataset_size)
        dataset, data_spec = self._process_dataset(dataset, self._hparams,
                                                   data_spec)
        self._data_spec = data_spec
        self._decoder = data_spec.decoder # pylint: disable=no-member

        # Batching
        dataset = self._make_batch(dataset, self._hparams)

        # Prefetching
        if self._hparams.prefetch_buffer_size > 0:
            dataset = dataset.prefetch(self._hparams.prefetch_buffer_size)

        self._dataset = dataset
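The stages wired together above follow the standard tf.data pipeline shape: shuffle, per-example processing, batching, prefetching. Below is a minimal standalone sketch of that shape using plain tf.data only; the file name and the toy map function are hypothetical stand-ins, not Texar's processors.

import tensorflow as tf

# Same pipeline shape as _make_data above, with placeholder steps.
dataset = tf.data.TextLineDataset("data.txt")   # hypothetical input file
dataset = dataset.shuffle(buffer_size=10000)    # shuffle raw examples
dataset = dataset.map(tf.strings.lower)         # stand-in for processing
dataset = dataset.batch(64)                     # batching
dataset = dataset.prefetch(buffer_size=1)       # prefetching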
Example #4
    def _make_processor(dataset_hparams, data_spec, name_prefix):
        processors = []
        for i, hparams_i in enumerate(dataset_hparams):
            data_spec_i = data_spec.get_ith_data_spec(i)

            data_type = hparams_i["data_type"]
            if _is_text_data(data_type):
                tgt_proc_hparams = hparams_i
                proc_shr = hparams_i["processing_share_with"]
                if proc_shr is not None:
                    tgt_proc_hparams = copy.copy(dataset_hparams[proc_shr])
                    try:
                        tgt_proc_hparams["variable_utterance"] = \
                                hparams_i["variable_utterance"]
                    except TypeError:
                        tgt_proc_hparams.variable_utterance = \
                                hparams_i["variable_utterance"]

                processor, data_spec_i = MonoTextData._make_processor(
                    tgt_proc_hparams, data_spec_i)
            elif _is_scalar_data(data_type):
                processor, data_spec_i = ScalarData._make_processor(
                    hparams_i, data_spec_i, name_prefix='')
            else:
                raise ValueError("Unsupported data type: %s" % data_type)

            processors.append(processor)
            data_spec.set_ith_data_spec(i, data_spec_i, len(dataset_hparams))

        tran_fn = dsutils.make_combined_transformation(processors,
                                                       name_prefix=name_prefix)

        data_spec.add_spec(name_prefix=name_prefix)

        return tran_fn, data_spec
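Here dsutils.make_combined_transformation turns one processor chain per field into a single transformation over the zipped example. The sketch below shows that combining pattern in plain Python; it is an assumption about the general shape only, not Texar's implementation (which also handles output naming via name_prefix).

def make_combined_transformation(processors):
    # Apply the i-th chain of functions to the i-th field of a zipped example.
    def combined(fields):
        outputs = []
        for fns, field in zip(processors, fields):
            for fn in fns:
                field = fn(field)
            outputs.append(field)
        return tuple(outputs)
    return combined

# Hypothetical usage: tokenize a text field, parse an integer label field.
transform = make_combined_transformation([[str.lower, str.split], [int]])
print(transform(("Hello World", "3")))  # (['hello', 'world'], 3)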
Example #5
    def _make_padded_shapes(self, dataset, src_decoder, tgt_decoder):
        src_text_and_id_shapes = {}
        if self._hparams.source_dataset.pad_to_max_seq_length:
            src_text_and_id_shapes = \
                    MonoTextData._make_padded_text_and_id_shapes(
                        dataset, self._hparams.source_dataset, src_decoder,
                        self.source_text_name, self.source_text_id_name)

        tgt_text_and_id_shapes = {}
        if self._hparams.target_dataset.pad_to_max_seq_length:
            tgt_text_and_id_shapes = \
                    MonoTextData._make_padded_text_and_id_shapes(
                        dataset, self._hparams.target_dataset, tgt_decoder,
                        self.target_text_name, self.target_text_id_name)

        padded_shapes = dataset.output_shapes
        padded_shapes.update(src_text_and_id_shapes)
        padded_shapes.update(tgt_text_and_id_shapes)

        return padded_shapes
Example #6
    def _make_length_filter(dataset_hparams, length_name, decoder):
        filter_fns = []
        for i, hpms in enumerate(dataset_hparams):
            if not _is_text_data(hpms["data_type"]):
                filter_fn = None
            else:
                filter_fn = MonoTextData._make_length_filter(
                    hpms, length_name[i], decoder[i])
            filter_fns.append(filter_fn)
        combined_filter_fn = dsutils._make_combined_filter_fn(filter_fns)
        return combined_filter_fn
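dsutils._make_combined_filter_fn is assumed to merge the per-field filters into one predicate over a zipped example, with None entries imposing no constraint. A minimal sketch of that behavior, mirroring the filter_fn pattern used in Example #15 below:

def make_combined_filter_fn(filter_fns):
    # Keep an example only if every non-None per-field filter accepts it.
    def combined(fields):
        return all(fn(field)
                   for fn, field in zip(filter_fns, fields)
                   if fn is not None)
    return combined

# Hypothetical usage: limit the source side to 5 tokens, leave the target free.
keep = make_combined_filter_fn([lambda s: len(s.split()) <= 5, None])
print(keep(("a short source line", "any target line of any length")))  # True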
Example #7
    def test_train_test_data_iterator(self):
        r"""Tests :class:`texar.data.TrainTestDataIterator`
        """
        train = MonoTextData(self._train_hparams)
        test = MonoTextData(self._test_hparams)
        train_batch_size = self._train_hparams["batch_size"]
        test_batch_size = self._test_hparams["batch_size"]

        data_iterator = TrainTestDataIterator(train=train, test=test)
        data_iterator.switch_to_train_data()
        iterator = data_iterator.get_iterator()

        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            # numbers: 1 - 2000, first 4 vocab entries are special tokens
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        data_iterator.switch_to_test_data()
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), test_batch_size)
            number = idx * test_batch_size + 1001
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test `get_*_iterator` interface
        for idx, val in enumerate(data_iterator.get_test_iterator()):
            self.assertEqual(len(val), test_batch_size)
            number = idx * test_batch_size + 1001
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test exception for invalid dataset name
        with self.assertRaises(ValueError) as context:
            data_iterator.switch_to_val_data()
        self.assertTrue('Val data not provided' in str(context.exception))
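As a companion to the test above, a minimal usage sketch of TrainTestDataIterator. The hparams are illustrative (the keys follow MonoTextData.default_hparams()) and the file paths are hypothetical.

from texar.data import MonoTextData, TrainTestDataIterator

train_hparams = {"batch_size": 2,
                 "dataset": {"files": "train.txt", "vocab_file": "vocab.txt"}}
test_hparams = {"batch_size": 2,
                "dataset": {"files": "test.txt", "vocab_file": "vocab.txt"}}

iterator = TrainTestDataIterator(train=MonoTextData(train_hparams),
                                 test=MonoTextData(test_hparams))

iterator.switch_to_train_data()
for batch in iterator.get_iterator():
    pass  # batch.text, batch.text_ids, batch.length, as asserted above

iterator.switch_to_test_data()
for batch in iterator.get_iterator():
    pass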
Example #8
    def _make_padded_shapes(self, dataset, decoders):
        padded_shapes = dataset.output_shapes
        for i, hparams_i in enumerate(self._hparams.datasets):
            if not _is_text_data(hparams_i["data_type"]):
                continue
            if not hparams_i["pad_to_max_seq_length"]:
                continue
            text_and_id_shapes = MonoTextData._make_padded_text_and_id_shapes(
                dataset, hparams_i, decoders[i], self.text_name(i),
                self.text_id_name(i))

            padded_shapes.update(text_and_id_shapes)

        return padded_shapes
Example #9
    def test_iterator_multi_datasets(self):
        r"""Tests iterating over multiple datasets.
        """
        train = MonoTextData(self._train_hparams)
        test = MonoTextData(self._test_hparams)
        train_batch_size = self._train_hparams["batch_size"]
        test_batch_size = self._test_hparams["batch_size"]
        data_iterator = DataIterator({"train": train, "test": test})
        data_iterator.switch_to_dataset(dataset_name="train")
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            # numbers: 1 - 2000, first 4 vocab entries are special tokens
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        data_iterator.switch_to_dataset(dataset_name="test")
        iterator = data_iterator.get_iterator()
        for idx, val in enumerate(iterator):
            self.assertEqual(len(val), test_batch_size)
            number = idx * test_batch_size + 1001
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test `get_iterator` interface
        for idx, val in enumerate(data_iterator.get_iterator('train')):
            self.assertEqual(len(val), train_batch_size)
            number = idx * train_batch_size + 1
            self.assertEqual(val.text[0], [str(number)])
            self.assertEqual(val.text_ids[0], torch.tensor(number + 3))

        # test exception for invalid dataset name
        with self.assertRaises(ValueError) as context:
            data_iterator.switch_to_dataset('val')
        self.assertTrue('not found' in str(context.exception))
Example #10
    def test_iterator_single_dataset(self):
        r"""Tests iterating over a single dataset.
        """
        data = MonoTextData(self._test_hparams)
        data_iterator = DataIterator(data)
        data_iterator.switch_to_dataset(dataset_name="data")
        iterator = data_iterator.get_iterator()
        i = 1001
        for idx, batch in enumerate(iterator):
            self.assertEqual(batch.batch_size,
                             self._test_hparams['batch_size'])
            np.testing.assert_array_equal(batch['length'], [1, 1])
            for example in batch['text']:
                self.assertEqual(example[0], str(i))
                i += 1
        self.assertEqual(i, 2001)
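The same pattern with the dict-based DataIterator used in the two tests above, as a minimal sketch; the hparams are illustrative as in the previous sketch and the file paths are hypothetical.

from texar.data import DataIterator, MonoTextData

train_hparams = {"batch_size": 2,
                 "dataset": {"files": "train.txt", "vocab_file": "vocab.txt"}}
test_hparams = {"batch_size": 2,
                "dataset": {"files": "test.txt", "vocab_file": "vocab.txt"}}

data_iterator = DataIterator({"train": MonoTextData(train_hparams),
                              "test": MonoTextData(test_hparams)})

data_iterator.switch_to_dataset("train")
for batch in data_iterator.get_iterator():
    pass

# Equivalently, request a named split directly:
for batch in data_iterator.get_iterator("test"):
    pass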
Example #11
    def make_embedding(src_emb_hparams, src_token_to_id_map,
                       tgt_emb_hparams=None, tgt_token_to_id_map=None,
                       emb_init_share=False):
        r"""Optionally loads source and target embeddings from files
        (if provided), and returns respective :class:`texar.data.Embedding`
        instances.
        """
        src_embedding = MonoTextData.make_embedding(src_emb_hparams,
                                                    src_token_to_id_map)

        if emb_init_share:
            tgt_embedding = src_embedding
        else:
            tgt_emb_file = tgt_emb_hparams["file"]
            tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                tgt_embedding = Embedding(tgt_token_to_id_map, tgt_emb_hparams)

        return src_embedding, tgt_embedding
Example #12
    def _make_processor(dataset_hparams, data_spec, chained=True,
                        name_prefix=None):
        # Create data decoder
        decoder = ScalarDataDecoder(
            ScalarData._get_dtype(dataset_hparams["data_type"]),
            data_name=name_prefix)
        # Create other transformations
        data_spec.add_spec(decoder=decoder)
        # pylint: disable=protected-access
        other_trans = MonoTextData._make_other_transformations(
            dataset_hparams["other_transformations"], data_spec)

        data_spec.add_spec(name_prefix=name_prefix)

        if chained:
            chained_tran = dsutils.make_chained_transformation(
                [decoder] + other_trans)
            return chained_tran, data_spec
        else:
            return decoder, other_trans, data_spec
Example #13
    def _make_processor(dataset_hparams, data_spec, chained=True,
                        name_prefix=None):
        # Create data decoder
        decoder = TFRecordDataDecoder(
            feature_original_types=dataset_hparams.feature_original_types,
            feature_convert_types=dataset_hparams.feature_convert_types,
            image_options=dataset_hparams.image_options)
        # Create other transformations
        data_spec.add_spec(decoder=decoder)
        # pylint: disable=protected-access
        other_trans = MonoTextData._make_other_transformations(
            dataset_hparams["other_transformations"], data_spec)

        data_spec.add_spec(name_prefix=name_prefix)

        if chained:
            chained_tran = dsutils.make_chained_transformation(
                [decoder] + other_trans)
            return chained_tran, data_spec
        else:
            return decoder, other_trans, data_spec
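Both Example #12 and Example #13 end by chaining the decoder with the other transformations. dsutils.make_chained_transformation is assumed to apply the given functions in order, roughly as sketched below; the real utility may thread additional arguments, so treat this as an illustration of the pattern only.

def make_chained_transformation(tran_fns):
    # Pipe the example through each transformation in turn.
    def chained(data):
        for fn in tran_fns:
            data = fn(data)
        return data
    return chained

# Hypothetical usage: decode a raw string field to an int, then clip it.
chained = make_chained_transformation([int, lambda x: min(x, 9)])
print(chained("42"))  # 9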
Example #14
    def make_vocab(src_hparams, tgt_hparams):
        """Reads vocab files and returns source vocab and target vocab.

        Args:
            src_hparams (dict or HParams): Hyperparameters of source dataset.
            tgt_hparams (dict or HParams): Hyperparameters of target dataset.

        Returns:
            A pair of :class:`texar.data.Vocab` instances. The two instances
            may be the same objects if source and target vocabs are shared
            and have the same other configs.
        """
        src_vocab = MonoTextData.make_vocab(src_hparams)

        if tgt_hparams["processing_share"]:
            tgt_bos_token = src_hparams["bos_token"]
            tgt_eos_token = src_hparams["eos_token"]
        else:
            tgt_bos_token = tgt_hparams["bos_token"]
            tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(tgt_bos_token, SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(tgt_eos_token, SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == src_vocab.bos_token and \
                    tgt_eos_token == src_vocab.eos_token:
                tgt_vocab = src_vocab
            else:
                tgt_vocab = Vocab(src_hparams["vocab_file"],
                                  bos_token=tgt_bos_token,
                                  eos_token=tgt_eos_token)
        else:
            tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                              bos_token=tgt_bos_token,
                              eos_token=tgt_eos_token)

        return src_vocab, tgt_vocab
Example #15
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        # Defaultizes hyperparameters of each dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for hparams_i in datasets_hparams:
            data_type = hparams_i.get("data_type", None)
            defaultized_ds_hpms = HParams(hparams_i,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        self._vocab = self.make_vocab(self._hparams.datasets)
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        name_prefix: List[str] = []
        self._names: List[Dict[str, Any]] = []
        datasources: List[DataSource] = []
        filters: List[Optional[Callable[[str], bool]]] = []
        self._databases: List[DataBase] = []
        for idx, hparams_i in enumerate(self._hparams.datasets):
            dtype = hparams_i.data_type
            datasource_i: DataSource

            if _is_text_data(dtype):
                datasource_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                datasources.append(datasource_i)
                if (hparams_i["length_filter_mode"] ==
                    _LengthFilterMode.DISCARD.value) and \
                        (hparams_i["max_seq_length"] is not None):

                    def _get_filter(delimiter, max_seq_length):
                        return lambda x: len(x.split(delimiter)) <= \
                               max_seq_length

                    filters.append(
                        _get_filter(hparams_i["delimiter"],
                                    hparams_i["max_seq_length"]))
                else:
                    filters.append(None)

                self._names.append({
                    "text":
                    connect_name(hparams_i["data_name"], "text"),
                    "text_ids":
                    connect_name(hparams_i["data_name"], "text_ids"),
                    "length":
                    connect_name(hparams_i["data_name"], "length")
                })

                text_hparams = MonoTextData.default_hparams()
                for key in text_hparams["dataset"].keys():
                    if key in hparams_i:
                        text_hparams["dataset"][key] = hparams_i[key]
                # handle prepend logic in MultiAlignedData collate function
                text_hparams["dataset"]["data_name"] = None

                self._databases.append(
                    MonoTextData._construct(hparams=text_hparams,
                                            device=device,
                                            vocab=self._vocab[idx],
                                            embedding=self._embedding[idx]))
            elif _is_scalar_data(dtype):
                datasource_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                datasources.append(datasource_i)
                filters.append(None)
                self._names.append({"label": hparams_i["data_name"]})
                scalar_hparams = ScalarData.default_hparams()
                scalar_hparams.update({"dataset": hparams_i.todict()})
                self._databases.append(
                    ScalarData._construct(hparams=scalar_hparams,
                                          device=device))
            elif _is_record_data(dtype):
                datasource_i = PickleDataSource(file_paths=hparams_i.files)
                datasources.append(datasource_i)
                self._names.append({
                    feature_type: connect_name(hparams_i["data_name"],
                                               feature_type)
                    for feature_type in
                    hparams_i.feature_original_types.keys()
                })
                filters.append(None)
                record_hparams = RecordData.default_hparams()
                for key in record_hparams["dataset"].keys():
                    if key in hparams_i:
                        record_hparams["dataset"][key] = hparams_i[key]
                self._databases.append(
                    RecordData._construct(hparams=record_hparams))
            else:
                raise ValueError("Unknown data type: %s" % hparams_i.data_type)

            # check for duplicate names
            for i in range(1, len(name_prefix)):
                if name_prefix[i] in name_prefix[:i]:
                    raise ValueError("Data name duplicated: %s" %
                                     name_prefix[i])

            name_prefix.append(hparams_i["data_name"])

        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

        datasource: DataSource
        datasource = ZipDataSource(*datasources)

        if any(filters):

            def filter_fn(data):
                return all(filters[idx](data_i)
                           for idx, data_i in enumerate(data)
                           if filters[idx] is not None)

            datasource = FilterDataSource(datasource, filter_fn=filter_fn)
        super().__init__(datasource, self._hparams, device)
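For reference, an illustrative "datasets" hparams list for this constructor. The keys mirror those read above (the text field additionally needs its vocabulary settings, shown here as vocab_file); the file paths and data names are hypothetical, the data_type values "text" and "int" are assumed, and omitted keys fall back to each data type's defaults.

from texar.data import MultiAlignedData

hparams = {
    "batch_size": 32,
    "datasets": [
        {   # text field
            "files": "source.txt",
            "vocab_file": "vocab.txt",
            "data_type": "text",
            "data_name": "src",
            "max_seq_length": 40,
            "length_filter_mode": "discard",
        },
        {   # scalar field, e.g. an integer label per line
            "files": "labels.txt",
            "data_type": "int",
            "data_name": "label",
        },
    ],
}
data = MultiAlignedData(hparams)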
Example #16
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())

        src_hparams = self.hparams.source_dataset
        tgt_hparams = self.hparams.target_dataset

        # create vocabulary
        self._src_bos_token = src_hparams["bos_token"]
        self._src_eos_token = src_hparams["eos_token"]
        self._src_transforms = src_hparams["other_transformations"]
        self._src_vocab = Vocab(src_hparams.vocab_file,
                                bos_token=src_hparams.bos_token,
                                eos_token=src_hparams.eos_token)

        if tgt_hparams["processing_share"]:
            self._tgt_bos_token = src_hparams["bos_token"]
            self._tgt_eos_token = src_hparams["eos_token"]
        else:
            self._tgt_bos_token = tgt_hparams["bos_token"]
            self._tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                          SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                          SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == self._src_vocab.bos_token and \
                    tgt_eos_token == self._src_vocab.eos_token:
                self._tgt_vocab = self._src_vocab
            else:
                self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                        bos_token=tgt_bos_token,
                                        eos_token=tgt_eos_token)
        else:
            self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)

        # create embeddings
        self._src_embedding = MonoTextData.make_embedding(
            src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)

        if self._hparams.target_dataset.embedding_init_share:
            self._tgt_embedding = self._src_embedding
        else:
            tgt_emb_file = tgt_hparams.embedding_init["file"]
            self._tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                self._tgt_embedding = MonoTextData.make_embedding(
                    tgt_hparams.embedding_init,
                    self._tgt_vocab.token_to_id_map_py)

        # create data source
        self._src_delimiter = src_hparams.delimiter
        self._src_max_seq_length = src_hparams.max_seq_length
        self._src_length_filter_mode = _LengthFilterMode(
            src_hparams.length_filter_mode)
        self._src_pad_length = self._src_max_seq_length
        if self._src_pad_length is not None:
            self._src_pad_length += sum(int(x is not None and x != '')
                                        for x in [src_hparams.bos_token,
                                                  src_hparams.eos_token])

        src_data_source = TextLineDataSource(
            src_hparams.files,
            compression_type=src_hparams.compression_type)

        self._tgt_transforms = tgt_hparams["other_transformations"]
        self._tgt_delimiter = tgt_hparams.delimiter
        self._tgt_max_seq_length = tgt_hparams.max_seq_length
        self._tgt_length_filter_mode = _LengthFilterMode(
            tgt_hparams.length_filter_mode)
        self._tgt_pad_length = self._tgt_max_seq_length
        if self._tgt_pad_length is not None:
            self._tgt_pad_length += sum(int(x is not None and x != '')
                                        for x in [tgt_hparams.bos_token,
                                                  tgt_hparams.eos_token])

        tgt_data_source = TextLineDataSource(
            tgt_hparams.files,
            compression_type=tgt_hparams.compression_type)

        data_source: DataSource[Tuple[str, str]]
        data_source = ZipDataSource(  # type: ignore
            src_data_source, tgt_data_source)
        if (self._src_length_filter_mode is _LengthFilterMode.DISCARD and
            self._src_max_seq_length is not None) or \
                (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD and
                 self._tgt_max_seq_length is not None):
            max_source_length = self._src_max_seq_length if \
                self._src_max_seq_length is not None else np.inf
            max_tgt_length = self._tgt_max_seq_length if \
                self._tgt_max_seq_length is not None else np.inf

            def filter_fn(raw_example):
                return len(raw_example[0].split(self._src_delimiter)) \
                       <= max_source_length and \
                       len(raw_example[1].split(self._tgt_delimiter)) \
                       <= max_tgt_length

            data_source = FilterDataSource(data_source, filter_fn)

        super().__init__(data_source, hparams, device=device)