Example #1
class ParallelData(DataBase[RawExample, Example]):
    def __init__(self,
                 source: DataSource[RawExample],
                 src_vocab_path: str,
                 tgt_vocab_path: str,
                 hparams: AnyDict,
                 device: Optional[torch.device] = None):
        # hparams.update(parallelize_processing=False)
        self.src_vocab = Vocab(src_vocab_path)
        self.tgt_vocab = Vocab(tgt_vocab_path)
        self.device = device
        super().__init__(source, hparams=hparams)

    def process(self, raw_example: RawExample) -> Example:
        src, tgt = raw_example.strip().split('\t')
        src = self.src_vocab.map_tokens_to_ids_py(src.split())
        tgt = self.tgt_vocab.map_tokens_to_ids_py(tgt.split())
        return src, tgt

    def collate(self, examples: List[Example]) -> Batch:
        src_pad_length = max(len(src) for src, _ in examples)
        tgt_pad_length = max(len(tgt) for _, tgt in examples)
        batch_size = len(examples)
        src_indices = np.zeros((batch_size, src_pad_length), dtype=np.int64)
        tgt_indices = np.zeros((batch_size, tgt_pad_length), dtype=np.int64)
        for b_idx, (src, tgt) in enumerate(examples):
            src_indices[b_idx, :len(src)] = src
            tgt_indices[b_idx, :len(tgt)] = tgt
        src_indices = torch.from_numpy(src_indices).to(device=self.device)
        tgt_indices = torch.from_numpy(tgt_indices).to(device=self.device)
        return Batch(batch_size, src=src_indices, tgt=tgt_indices)
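
For context, a minimal usage sketch for the ParallelData class above (not part of the original example). The import path, file names, and hyperparameter values are assumptions, following texar's DataIterator / TextLineDataSource conventions.

from texar.data import DataIterator, TextLineDataSource

# Hypothetical tab-separated parallel corpus: each line is "<src>\t<tgt>".
source = TextLineDataSource('train.tsv')
data = ParallelData(source, 'vocab.src', 'vocab.tgt',
                    hparams={'batch_size': 2, 'shuffle': False})
for batch in DataIterator(data):
    # batch.src and batch.tgt are the padded [batch_size, max_len] index tensors
    # built in `collate` above.
    print(batch.src.shape, batch.tgt.shape)
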
Example #2
    def make_vocab(hparams):
        r"""Makes a list of vocabs based on the hyperparameters.

        Args:
            hparams (list): A list of dataset hyperparameters.

        Returns:
            A list of :class:`texar.data.Vocab` instances. Some instances
            may be the same objects if they are set to be shared and have
            the same other configurations.
        """
        if not isinstance(hparams, (list, tuple)):
            hparams = [hparams]

        vocabs = []
        for i, hparams_i in enumerate(hparams):
            if not _is_text_data(hparams_i["data_type"]):
                vocabs.append(None)
                continue

            proc_share = hparams_i["processing_share_with"]
            if proc_share is not None:
                bos_token = hparams[proc_share]["bos_token"]
                eos_token = hparams[proc_share]["eos_token"]
            else:
                bos_token = hparams_i["bos_token"]
                eos_token = hparams_i["eos_token"]
            bos_token = utils.default_str(
                bos_token, SpecialTokens.BOS)
            eos_token = utils.default_str(
                eos_token, SpecialTokens.EOS)

            vocab_share = hparams_i["vocab_share_with"]
            if vocab_share is not None:
                if vocab_share >= i:
                    MultiAlignedData._raise_sharing_error(
                        i, vocab_share, "vocab_share_with")
                if vocabs[vocab_share] is None:
                    raise ValueError("Cannot share vocab with dataset %d which "
                                     "does not have a vocab." % vocab_share)
                if bos_token == vocabs[vocab_share].bos_token and \
                        eos_token == vocabs[vocab_share].eos_token:
                    vocab = vocabs[vocab_share]
                else:
                    vocab = Vocab(hparams[vocab_share]["vocab_file"],
                                  bos_token=bos_token,
                                  eos_token=eos_token)
            else:
                vocab = Vocab(hparams_i["vocab_file"],
                              bos_token=bos_token,
                              eos_token=eos_token)
            vocabs.append(vocab)

        return vocabs
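
A hedged sketch of the hparams list this make_vocab expects (the values and vocab file name are hypothetical): dataset 1 shares both vocabulary and processing configuration with dataset 0, so the same Vocab object is returned for both positions.

hparams = [
    {"data_type": "text", "vocab_file": "vocab.txt",
     "bos_token": "", "eos_token": "",
     "vocab_share_with": None, "processing_share_with": None},
    {"data_type": "text", "vocab_file": "",
     "bos_token": "", "eos_token": "",
     "vocab_share_with": 0, "processing_share_with": 0},
]
vocabs = MultiAlignedData.make_vocab(hparams)
assert vocabs[1] is vocabs[0]   # shared: same BOS/EOS and vocab_share_with=0
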
Example #4
    def make_vocab(hparams):
        """Reads vocab file and returns an instance of
        :class:`texar.data.Vocab`.
        """
        bos_token = utils.default_str(hparams["bos_token"], SpecialTokens.BOS)
        eos_token = utils.default_str(hparams["eos_token"], SpecialTokens.EOS)
        vocab = Vocab(hparams["vocab_file"],
                      bos_token=bos_token,
                      eos_token=eos_token)
        return vocab
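
A minimal call sketch for the function above; the vocab file name is hypothetical, and the bare call form is an assumption (the snippet reads as a static method, e.g. MonoTextData.make_vocab). Empty BOS/EOS entries fall back to the special-token defaults via utils.default_str.

vocab = make_vocab({"vocab_file": "vocab.txt",
                    "bos_token": "",
                    "eos_token": ""})
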
Example #5
    def test_map_ids_to_strs(self):
        """Tests :func:`texar.utils.map_ids_to_strs`.
        """
        vocab_list = ['word', '词']
        vocab_file = tempfile.NamedTemporaryFile()
        vocab_file.write('\n'.join(vocab_list).encode("utf-8"))
        vocab_file.flush()
        vocab = Vocab(vocab_file.name)

        text = [['<BOS>', 'word', '词', '<EOS>', '<PAD>'],
                ['word', '词', 'word', '词', '<PAD>']]
        text = np.asarray(text)
        ids = vocab.map_tokens_to_ids_py(text)

        ids = ids.tolist()
        text_ = utils.map_ids_to_strs(ids, vocab)

        self.assertEqual(text_[0], 'word 词')
        self.assertEqual(text_[1], 'word 词 word 词')
Example #7
    def _construct(cls,
                   hparams,
                   device: Optional[torch.device] = None,
                   vocab: Optional[Vocab] = None,
                   embedding: Optional[Embedding] = None):
        mono_text_data = cls.__new__(cls)
        mono_text_data._hparams = HParams(hparams,
                                          mono_text_data.default_hparams())
        if mono_text_data._hparams.dataset.variable_utterance:
            raise NotImplementedError

        dataset = mono_text_data._hparams.dataset
        mono_text_data._other_transforms = dataset.other_transformations

        # Create vocabulary
        if vocab is not None:
            mono_text_data._vocab = vocab
            mono_text_data._bos_token = vocab.bos_token
            mono_text_data._eos_token = vocab.eos_token
        else:
            mono_text_data._bos_token = dataset.bos_token
            mono_text_data._eos_token = dataset.eos_token
            bos = utils.default_str(mono_text_data._bos_token,
                                    SpecialTokens.BOS)
            eos = utils.default_str(mono_text_data._eos_token,
                                    SpecialTokens.EOS)
            mono_text_data._vocab = Vocab(dataset.vocab_file,
                                          bos_token=bos,
                                          eos_token=eos)

        # Create embedding
        if embedding is not None:
            mono_text_data._embedding = embedding
        else:
            mono_text_data._embedding = mono_text_data.make_embedding(
                dataset.embedding_init,
                mono_text_data._vocab.token_to_id_map_py)

        mono_text_data._delimiter = dataset.delimiter
        mono_text_data._max_seq_length = dataset.max_seq_length
        mono_text_data._length_filter_mode = _LengthFilterMode(
            mono_text_data._hparams.dataset.length_filter_mode)
        mono_text_data._pad_length = mono_text_data._max_seq_length
        if mono_text_data._pad_length is not None:
            mono_text_data._pad_length += sum(
                int(x != '') for x in
                [mono_text_data._bos_token, mono_text_data._eos_token])

        data_source: SequenceDataSource[str] = SequenceDataSource([])
        super(MonoTextData, mono_text_data).__init__(source=data_source,
                                                     hparams=hparams,
                                                     device=device)

        return mono_text_data
    def make_vocab(src_hparams, tgt_hparams):
        """Reads vocab files and returns source vocab and target vocab.

        Args:
            src_hparams (dict or HParams): Hyperparameters of source dataset.
            tgt_hparams (dict or HParams): Hyperparameters of target dataset.

        Returns:
            A pair of :class:`texar.data.Vocab` instances. The two instances
            may be the same objects if source and target vocabs are shared
            and have the same other configs.
        """
        src_vocab = MonoTextData.make_vocab(src_hparams)

        if tgt_hparams["processing_share"]:
            tgt_bos_token = src_hparams["bos_token"]
            tgt_eos_token = src_hparams["eos_token"]
        else:
            tgt_bos_token = tgt_hparams["bos_token"]
            tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(tgt_bos_token, SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(tgt_eos_token, SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == src_vocab.bos_token and \
                    tgt_eos_token == src_vocab.eos_token:
                tgt_vocab = src_vocab
            else:
                tgt_vocab = Vocab(src_hparams["vocab_file"],
                                  bos_token=tgt_bos_token,
                                  eos_token=tgt_eos_token)
        else:
            tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                              bos_token=tgt_bos_token,
                              eos_token=tgt_eos_token)

        return src_vocab, tgt_vocab
class MonoTextData(TextDataBase[str, List[str]]):
    r"""Text data processor that reads single set of text files. This can be
    used for, e.g., language models, auto-encoders, etc.

    Args:
        hparams: A `dict` or instance of :class:`~texar.HParams` containing
            hyperparameters. See :meth:`default_hparams` for the defaults.

    By default, the processor reads raw data files, performs tokenization,
    batching and other pre-processing steps, and results in a Dataset
    whose element is a python `dict` including three fields:

    "text":
        A list of ``[batch_size]`` elements each containing a list of
        **raw** text tokens of the sequences. Short sequences in the batch
        are padded with **empty string**. By default only ``EOS`` token is
        appended to each sequence. Out-of-vocabulary tokens are **NOT**
        replaced with ``UNK``.
    "text_ids":
        A list of ``[batch_size]`` elements each containing a list of token
        indexes of source sequences in the batch.
    "length":
        A list of ``[batch_size]`` elements of ints containing the length
        of each source sequence in the batch (including ``BOS`` and ``EOS``
        if added).

    The above field names can be accessed through :attr:`text_name`,
    :attr:`text_id_name`, :attr:`length_name`.

    Example:

        .. code-block:: python

            hparams={
                'dataset': { 'files': 'data.txt', 'vocab_file': 'vocab.txt' },
                'batch_size': 1
            }
            data = MonoTextData(hparams)
            iterator = DataIterator(data)
            for batch in iterator:
                # batch contains the following
                # batch == {
                #    'text': [['<BOS>', 'example', 'sequence', '<EOS>']],
                #    'text_ids': [[1, 5, 10, 2]],
                #    'length': [4]
                # }
    """

    _delimiter: str
    _bos: Optional[str]
    _eos: Optional[str]
    _max_seq_length: Optional[int]
    _should_pad: bool

    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        if self._hparams.dataset.variable_utterance:
            raise NotImplementedError

        # Create vocabulary
        self._bos_token = self._hparams.dataset.bos_token
        self._eos_token = self._hparams.dataset.eos_token
        self._other_transforms = self._hparams.dataset.other_transformations
        bos = utils.default_str(self._bos_token, SpecialTokens.BOS)
        eos = utils.default_str(self._eos_token, SpecialTokens.EOS)
        self._vocab = Vocab(self._hparams.dataset.vocab_file,
                            bos_token=bos, eos_token=eos)

        # Create embedding
        self._embedding = self.make_embedding(
            self._hparams.dataset.embedding_init,
            self._vocab.token_to_id_map_py)

        self._delimiter = self._hparams.dataset.delimiter
        self._max_seq_length = self._hparams.dataset.max_seq_length
        self._length_filter_mode = _LengthFilterMode(
            self._hparams.dataset.length_filter_mode)
        self._pad_length = self._max_seq_length
        if self._pad_length is not None:
            self._pad_length += sum(int(x != '')
                                    for x in [self._bos_token, self._eos_token])

        if (self._length_filter_mode is _LengthFilterMode.DISCARD and
                self._max_seq_length is not None):
            data_source = TextLineDataSource(
                self._hparams.dataset.files,
                compression_type=self._hparams.dataset.compression_type,
                delimiter=self._delimiter,
                max_length=self._max_seq_length)
        else:
            data_source = TextLineDataSource(
                self._hparams.dataset.files,
                compression_type=self._hparams.dataset.compression_type)

        super().__init__(data_source, hparams, device=device)

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of default hyperparameters:

        .. code-block:: python

            {
                # (1) Hyperparameters specific to text dataset
                "dataset": {
                    "files": [],
                    "compression_type": None,
                    "vocab_file": "",
                    "embedding_init": {},
                    "delimiter": " ",
                    "max_seq_length": None,
                    "length_filter_mode": "truncate",
                    "pad_to_max_seq_length": False,
                    "bos_token": "<BOS>"
                    "eos_token": "<EOS>"
                    "other_transformations": [],
                    "variable_utterance": False,
                    "utterance_delimiter": "|||",
                    "max_utterance_cnt": 5,
                    "data_name": None,
                }
                # (2) General hyperparameters
                "num_epochs": 1,
                "batch_size": 64,
                "allow_smaller_final_batch": True,
                "shuffle": True,
                "shuffle_buffer_size": None,
                "shard_and_shuffle": False,
                "num_parallel_calls": 1,
                "prefetch_buffer_size": 0,
                "max_dataset_size": -1,
                "seed": None,
                "name": "mono_text_data",
                # (3) Bucketing
                "bucket_boundaries": [],
                "bucket_batch_sizes": None,
                "bucket_length_fn": None,
            }

        Here:

        1. For the hyperparameters in the :attr:`"dataset"` field:

          "files" : str or list
              A (list of) text file path(s).

              Each line contains a single text sequence.

          "compression_type" : str, optional
              One of ``None`` (no compression), ``"ZLIB"``, or ``"GZIP"``.

          "vocab_file": str
              Path to vocabulary file. Each line of the file should contain
              one vocabulary token.

              Used to create an instance of :class:`~texar.data.Vocab`.

          "embedding_init" : dict
              The hyperparameters for pre-trained embedding loading and
              initialization.

              The structure and default values are defined in
              :meth:`texar.data.Embedding.default_hparams`.

          "delimiter" : str
              The delimiter to split each line of the text files into tokens.

          "max_seq_length" : int, optional
              Maximum length of output sequences. Data samples exceeding the
              length will be truncated or discarded according to
              :attr:`"length_filter_mode"`. The length does not include
              any added
              :attr:`"bos_token"` or :attr:`"eos_token"`. If `None` (default),
              no filtering is performed.

          "length_filter_mode" : str
              Either ``"truncate"`` or ``"discard"``. If ``"truncate"``
              (default), tokens exceeding :attr:`"max_seq_length"` will be
              truncated.
              If ``"discard"``, data samples longer than
              :attr:`"max_seq_length"` will be discarded.

          "pad_to_max_seq_length" : bool
              If `True`, pad all data instances to length
              :attr:`"max_seq_length"`.
              Raises error if :attr:`"max_seq_length"` is not provided.

          "bos_token" : str
              The Begin-Of-Sequence token prepended to each sequence.

              Set to an empty string to avoid prepending.

          "eos_token" : str
              The End-Of-Sequence token appended to each sequence.

              Set to an empty string to avoid appending.

          "other_transformations" : list
              A list of transformation functions or function names/paths to
              further transform each single data instance.

              (More documentations to be added.)

          "variable_utterance" : bool
              If `True`, each line of the text file is considered to contain
              multiple sequences (utterances) separated by
              :attr:`"utterance_delimiter"`.

              For example, in dialog data, each line can contain a series of
              dialog history utterances. See the example in
              `examples/hierarchical_dialog` for a use case.

              .. warning::
                  Variable utterances is not yet supported. This option (and
                  related ones below) will be ignored.

          "utterance_delimiter" : str
              The delimiter used to split each line into utterances. It should
              not be the same as :attr:`"delimiter"`. Used only when
              :attr:`"variable_utterance"` is ``True``.

          "max_utterance_cnt" : int
              The maximum number of utterances allowed in a data instance.
              Extra utterances are truncated.

          "data_name" : str
              Name of the dataset.

        2. For the **general** hyperparameters, see
        :meth:`texar.data.DataBase.default_hparams` for details.

        3. **Bucketing** groups elements of the dataset by length, and then
        pads and batches them. For the bucketing hyperparameters:

          "bucket_boundaries" : list
              An int list containing the upper length boundaries of the
              buckets.

              Set to an empty list (default) to disable bucketing.

          "bucket_batch_sizes" : list
              An int list containing batch size per bucket. Length should be
              `len(bucket_boundaries) + 1`.

              If `None`, every bucket will have the same batch size specified
              in :attr:`batch_size`.

          "bucket_length_fn" : str or callable
              A function that maps a dataset element to an ``int``, used to
              determine the length of the element.

              This can be a function, or the name or full module path to the
              function. If function name is given, the function must be in the
              :mod:`texar.custom` module.

              If `None` (default), length is determined by the number of
              tokens (including BOS and EOS if added) of the element.

          .. warning::
              Bucketing is not yet supported. These options will be ignored.

        """
        hparams = TextDataBase.default_hparams()
        hparams["name"] = "mono_text_data"
        hparams.update({
            "dataset": _default_mono_text_dataset_hparams()
        })
        return hparams

    @staticmethod
    def make_embedding(emb_hparams, token_to_id_map):
        r"""Optionally loads embedding from file (if provided), and returns
        an instance of :class:`texar.data.Embedding`.
        """
        embedding = None
        if emb_hparams["file"] is not None and len(emb_hparams["file"]) > 0:
            embedding = Embedding(token_to_id_map, emb_hparams)
        return embedding

    def _process(self, raw_example: str) -> List[str]:
        # `_process` truncates sentences and appends BOS/EOS tokens.
        words = raw_example.split(self._delimiter)
        if (self._max_seq_length is not None and
                len(words) > self._max_seq_length):
            if self._length_filter_mode is _LengthFilterMode.TRUNC:
                words = words[:self._max_seq_length]

        if self._bos_token != '':
            words.insert(0, self._bos_token)
        if self._eos_token != '':
            words.append(self._eos_token)

        # Apply the "other transformations".
        for transform in self._other_transforms:
            words = transform(words)

        return words

    def _collate(self, examples: List[List[str]]) -> Batch:
        # For `MonoTextData`, each example is represented as a list of strings.
        # `_collate` takes care of padding and numericalization.

        # If `pad_length` is `None`, pad to the longest sentence in the batch.
        text_ids = [self._vocab.map_tokens_to_ids_py(sent) for sent in examples]
        text_ids, lengths = padded_batch(text_ids, self._pad_length,
                                         pad_value=self._vocab.pad_token_id)
        # Also pad the examples
        pad_length = self._pad_length or max(lengths)
        examples = [
            sent + [''] * (pad_length - len(sent))
            if len(sent) < pad_length else sent
            for sent in examples
        ]

        text_ids = torch.from_numpy(text_ids).to(device=self.device)
        lengths = torch.tensor(lengths, dtype=torch.long, device=self.device)
        return Batch(len(examples), text=examples,
                     text_ids=text_ids, length=lengths)

    def list_items(self) -> List[str]:
        r"""Returns the list of item names that the data can produce.

        Returns:
            A list of strings.
        """
        items = ['text', 'text_ids', 'length']
        data_name = self._hparams.dataset.data_name
        if data_name is not None:
            items = [data_name + '_' + item for item in items]
        return items

    @property
    def vocab(self) -> Vocab:
        r"""The vocabulary, an instance of :class:`~texar.data.Vocab`.
        """
        return self._vocab

    @property
    def text_name(self):
        r"""The name for the text field"""
        if self.hparams.dataset["data_name"]:
            name = "{}_text".format(self.hparams.dataset["data_name"])
        else:
            name = "text"
        return name

    @property
    def text_id_name(self):
        r"""The name for text ids"""
        if self.hparams.dataset["data_name"]:
            name = "{}_text_ids".format(self.hparams.dataset["data_name"])
        else:
            name = "text_ids"
        return name

    @property
    def length_name(self):
        r"""The name for text length"""
        if self.hparams.dataset["data_name"]:
            name = "{}_length".format(self.hparams.dataset["data_name"])
        else:
            name = "length"
        return name

    @property
    def embedding_init_value(self):
        r"""The `Tensor` containing the embedding value loaded from file.
        `None` if embedding is not specified.
        """
        if self._embedding is None:
            return None
        return self._embedding.word_vecs
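
To complement the hyperparameter documentation above, a hedged configuration sketch (file names are hypothetical): sequences longer than max_seq_length are discarded rather than truncated, and the data is iterated as in the class docstring example.

hparams = {
    'dataset': {
        'files': 'data.txt',
        'vocab_file': 'vocab.txt',
        'max_seq_length': 20,
        'length_filter_mode': 'discard',
    },
    'batch_size': 32,
}
data = MonoTextData(hparams)
for batch in DataIterator(data):
    # Reported lengths include the added BOS/EOS tokens, hence the +2.
    assert batch['length'].max() <= 20 + 2
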
Example #10
class PairedTextData(TextDataBase[Tuple[str, str],
                                  Tuple[List[str], List[str]]]):
    r"""Text data processor that reads parallel source and target text.
    This can be used in, e.g., seq2seq models.

    Args:
        hparams (dict): Hyperparameters. See :meth:`default_hparams` for the
            defaults.

    By default, the processor reads raw data files, performs tokenization,
    batching and other pre-processing steps, and results in a Dataset
    whose element is a python `dict` including six fields:

    "source_text":
        A list of ``[batch_size]`` elements each containing a list of
        **raw** text tokens of source sequences. Short sequences in the
        batch are padded with **empty string**. By default only ``EOS``
        token is appended to each sequence. Out-of-vocabulary tokens are
        **NOT** replaced with ``UNK``.
    "source_text_ids":
        A list of ``[batch_size]`` elements each containing a list of token
        indexes of source sequences in the batch.
    "source_length":
        A list of ``[batch_size]`` elements of ints containing the length
        of each source sequence in the batch.
    "target_text":
        A list same as "source_text" but for target sequences. By default
        both BOS and EOS are added.
    "target_text_ids":
        A list same as "source_text_ids" but for target sequences.
    "target_length":
        A list same as "source_length" but for target sequences.

    The above field names can be accessed through :attr:`source_text_name`,
    :attr:`source_text_id_name`, :attr:`source_length_name`, and those prefixed
    with ``target_``, respectively.

    Example:

    .. code-block:: python

        hparams={
            'source_dataset': {'files': 's', 'vocab_file': 'vs'},
            'target_dataset': {'files': ['t1', 't2'], 'vocab_file': 'vt'},
            'batch_size': 1
        }
        data = PairedTextData(hparams)
        iterator = DataIterator(data)

        for batch in iterator:
            # batch contains the following
            # batch == {
            #    'source_text': [['source', 'sequence', '<EOS>']],
            #    'source_text_ids': [[5, 10, 2]],
            #    'source_length': [3]
            #    'target_text': [['<BOS>', 'target', 'sequence', '1',
            #                     '<EOS>']],
            #    'target_text_ids': [[1, 6, 10, 20, 2]],
            #    'target_length': [5]
            # }

    """

    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())

        src_hparams = self.hparams.source_dataset
        tgt_hparams = self.hparams.target_dataset

        # create vocabulary
        self._src_bos_token = src_hparams["bos_token"]
        self._src_eos_token = src_hparams["eos_token"]
        self._src_transforms = src_hparams["other_transformations"]
        self._src_vocab = Vocab(src_hparams.vocab_file,
                                bos_token=src_hparams.bos_token,
                                eos_token=src_hparams.eos_token)

        if tgt_hparams["processing_share"]:
            self._tgt_bos_token = src_hparams["bos_token"]
            self._tgt_eos_token = src_hparams["eos_token"]
        else:
            self._tgt_bos_token = tgt_hparams["bos_token"]
            self._tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                          SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                          SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == self._src_vocab.bos_token and \
                    tgt_eos_token == self._src_vocab.eos_token:
                self._tgt_vocab = self._src_vocab
            else:
                self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                        bos_token=tgt_bos_token,
                                        eos_token=tgt_eos_token)
        else:
            self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)

        # create embeddings
        self._src_embedding = MonoTextData.make_embedding(
            src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)

        if self._hparams.target_dataset.embedding_init_share:
            self._tgt_embedding = self._src_embedding
        else:
            tgt_emb_file = tgt_hparams.embedding_init["file"]
            self._tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                self._tgt_embedding = MonoTextData.make_embedding(
                    tgt_hparams.embedding_init,
                    self._tgt_vocab.token_to_id_map_py)

        # create data source
        self._src_delimiter = src_hparams.delimiter
        self._src_max_seq_length = src_hparams.max_seq_length
        self._src_length_filter_mode = _LengthFilterMode(
            src_hparams.length_filter_mode)
        self._src_pad_length = self._src_max_seq_length
        if self._src_pad_length is not None:
            self._src_pad_length += sum(int(x is not None and x != '')
                                        for x in [src_hparams.bos_token,
                                                  src_hparams.eos_token])

        src_data_source = TextLineDataSource(
            src_hparams.files,
            compression_type=src_hparams.compression_type)

        self._tgt_transforms = tgt_hparams["other_transformations"]
        self._tgt_delimiter = tgt_hparams.delimiter
        self._tgt_max_seq_length = tgt_hparams.max_seq_length
        self._tgt_length_filter_mode = _LengthFilterMode(
            tgt_hparams.length_filter_mode)
        self._tgt_pad_length = self._tgt_max_seq_length
        if self._tgt_pad_length is not None:
            self._tgt_pad_length += sum(int(x is not None and x != '')
                                        for x in [tgt_hparams.bos_token,
                                                  tgt_hparams.eos_token])

        tgt_data_source = TextLineDataSource(
            tgt_hparams.files,
            compression_type=tgt_hparams.compression_type)

        data_source: DataSource[Tuple[str, str]]
        data_source = ZipDataSource(  # type: ignore
            src_data_source, tgt_data_source)
        if (self._src_length_filter_mode is _LengthFilterMode.DISCARD and
            self._src_max_seq_length is not None) or \
                (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD and
                 self._tgt_max_seq_length is not None):
            max_source_length = self._src_max_seq_length if \
                self._src_max_seq_length is not None else np.inf
            max_tgt_length = self._tgt_max_seq_length if \
                self._tgt_max_seq_length is not None else np.inf

            def filter_fn(raw_example):
                return len(raw_example[0].split(self._src_delimiter)) \
                       <= max_source_length and \
                       len(raw_example[1].split(self._tgt_delimiter)) \
                       <= max_tgt_length

            data_source = FilterDataSource(data_source, filter_fn)

        super().__init__(data_source, hparams, device=device)

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of default hyperparameters.

        .. code-block:: python

            {
                # (1) Hyperparams specific to text dataset
                "source_dataset": {
                    "files": [],
                    "compression_type": None,
                    "vocab_file": "",
                    "embedding_init": {},
                    "delimiter": " ",
                    "max_seq_length": None,
                    "length_filter_mode": "truncate",
                    "pad_to_max_seq_length": False,
                    "bos_token": None,
                    "eos_token": "<EOS>",
                    "other_transformations": [],
                    "variable_utterance": False,
                    "utterance_delimiter": "|||",
                    "max_utterance_cnt": 5,
                    "data_name": "source",
                },
                "target_dataset": {
                    # ...
                    # Same fields are allowed as in "source_dataset" with the
                    # same default values, except the
                    # following new fields/values:
                    "bos_token": "<BOS>"
                    "vocab_share": False,
                    "embedding_init_share": False,
                    "processing_share": False,
                    "data_name": "target"
                }
                # (2) General hyperparams
                "num_epochs": 1,
                "batch_size": 64,
                "allow_smaller_final_batch": True,
                "shuffle": True,
                "shuffle_buffer_size": None,
                "shard_and_shuffle": False,
                "num_parallel_calls": 1,
                "prefetch_buffer_size": 0,
                "max_dataset_size": -1,
                "seed": None,
                "name": "paired_text_data",
                # (3) Bucketing
                "bucket_boundaries": [],
                "bucket_batch_sizes": None,
                "bucket_length_fn": None,
            }

        Here:

        1. Hyperparameters in the :attr:`"source_dataset"` and
           attr:`"target_dataset"` fields have the same definition as those
           in :meth:`texar.data.MonoTextData.default_hparams`, for source and
           target text, respectively.

           For the new hyperparameters in "target_dataset":

           "vocab_share" : bool
               Whether to share the vocabulary of source.
               If `True`, the vocab file of target is ignored.

           "embedding_init_share" : bool
               Whether to share the embedding initial value of source. If
               `True`, :attr:`"embedding_init"` of target is ignored.

              :attr:`"vocab_share"` must be true to share the embedding
              initial value.

           "processing_share" : bool
               Whether to share the processing configurations of source,
               including "delimiter", "bos_token", "eos_token", and
               "other_transformations".

        2. For the **general** hyperparameters, see
           :meth:`texar.data.DataBase.default_hparams` for details.

        3. For **bucketing** hyperparameters, see
           :meth:`texar.data.MonoTextData.default_hparams` for details, except
           that the default :attr:`bucket_length_fn` is the maximum sequence
           length of source and target sequences.

           .. warning::
               Bucketing is not yet supported. These options will be ignored.

        """
        hparams = TextDataBase.default_hparams()
        hparams["name"] = "paired_text_data"
        hparams.update(_default_paired_text_dataset_hparams())
        return hparams

    @staticmethod
    def make_embedding(src_emb_hparams, src_token_to_id_map,
                       tgt_emb_hparams=None, tgt_token_to_id_map=None,
                       emb_init_share=False):
        r"""Optionally loads source and target embeddings from files
        (if provided), and returns respective :class:`texar.data.Embedding`
        instances.
        """
        src_embedding = MonoTextData.make_embedding(src_emb_hparams,
                                                    src_token_to_id_map)

        if emb_init_share:
            tgt_embedding = src_embedding
        else:
            tgt_emb_file = tgt_emb_hparams["file"]
            tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                tgt_embedding = Embedding(tgt_token_to_id_map, tgt_emb_hparams)

        return src_embedding, tgt_embedding

    def _process(self, raw_example: Tuple[str, str]) -> \
            Tuple[List[str], List[str]]:
        # `_process` truncates sentences and appends BOS/EOS tokens.
        src_words = raw_example[0].split(self._src_delimiter)
        if (self._src_max_seq_length is not None and
                len(src_words) > self._src_max_seq_length):
            if self._src_length_filter_mode is _LengthFilterMode.TRUNC:
                src_words = src_words[:self._src_max_seq_length]

        if self._src_bos_token is not None and self._src_bos_token != '':
            src_words.insert(0, self._src_bos_token)
        if self._src_eos_token is not None and self._src_eos_token != '':
            src_words.append(self._src_eos_token)

        # apply the transformations to source
        for transform in self._src_transforms:
            src_words = transform(src_words)

        tgt_words = raw_example[1].split(self._tgt_delimiter)
        if (self._tgt_max_seq_length is not None and
                len(tgt_words) > self._tgt_max_seq_length):
            if self._tgt_length_filter_mode is _LengthFilterMode.TRUNC:
                tgt_words = tgt_words[:self._tgt_max_seq_length]

        if self._tgt_bos_token is not None and self._tgt_bos_token != '':
            tgt_words.insert(0, self._tgt_bos_token)
        if self._tgt_eos_token is not None and self._tgt_eos_token != '':
            tgt_words.append(self._tgt_eos_token)

        # apply the transformations to target
        for transform in self._tgt_transforms:
            tgt_words = transform(tgt_words)

        return src_words, tgt_words

    @staticmethod
    def _get_name_prefix(src_hparams, tgt_hparams):
        name_prefix = [
            src_hparams["data_name"], tgt_hparams["data_name"]]
        if name_prefix[0] == name_prefix[1]:
            raise ValueError("'data_name' of source and target "
                             "datasets cannot be the same.")
        return name_prefix

    def _collate(self, examples: List[Tuple[List[str], List[str]]]) -> Batch:
        # For `PairedTextData`, each example is represented as a tuple of list
        # of strings.
        # `_collate` takes care of padding and numericalization.

        # If `pad_length` is `None`, pad to the longest sentence in the batch.
        src_examples = [example[0] for example in examples]
        source_ids = [self._src_vocab.map_tokens_to_ids_py(sent) for sent
                      in src_examples]
        source_ids, source_lengths = \
            padded_batch(source_ids,
                         self._src_pad_length,
                         pad_value=self._src_vocab.pad_token_id)
        src_pad_length = self._src_pad_length or max(source_lengths)
        src_examples = [
            sent + [''] * (src_pad_length - len(sent))
            if len(sent) < src_pad_length else sent
            for sent in src_examples
        ]

        source_ids = torch.from_numpy(source_ids).to(device=self.device)
        source_lengths = torch.tensor(source_lengths, dtype=torch.long,
                                      device=self.device)

        tgt_examples = [example[1] for example in examples]
        target_ids = [self._tgt_vocab.map_tokens_to_ids_py(sent) for sent
                      in tgt_examples]
        target_ids, target_lengths = \
            padded_batch(target_ids,
                         self._tgt_pad_length,
                         pad_value=self._tgt_vocab.pad_token_id)
        tgt_pad_length = self._tgt_pad_length or max(target_lengths)
        tgt_examples = [
            sent + [''] * (tgt_pad_length - len(sent))
            if len(sent) < tgt_pad_length else sent
            for sent in tgt_examples
        ]

        target_ids = torch.from_numpy(target_ids).to(device=self.device)
        target_lengths = torch.tensor(target_lengths, dtype=torch.long,
                                      device=self.device)

        return Batch(len(examples), source_text=src_examples,
                     source_text_ids=source_ids, source_length=source_lengths,
                     target_text=tgt_examples, target_text_ids=target_ids,
                     target_length=target_lengths)

    def list_items(self) -> List[str]:
        r"""Returns the list of item names that the data can produce.

        Returns:
            A list of strings.
        """
        items = ['text', 'text_ids', 'length']
        src_name = self._hparams.source_dataset['data_name']
        tgt_name = self._hparams.target_dataset['data_name']

        if src_name is not None:
            src_items = [src_name + '_' + item for item in items]
        else:
            src_items = items

        if tgt_name is not None:
            tgt_items = [tgt_name + '_' + item for item in items]
        else:
            tgt_items = items

        return src_items + tgt_items

    @property
    def dataset(self):
        r"""The dataset.
        """
        return self._source

    @property
    def vocab(self):
        r"""A pair instances of :class:`~texar.data.Vocab` that are source
        and target vocabs, respectively.
        """
        return self._src_vocab, self._tgt_vocab

    @property
    def source_vocab(self):
        r"""The source vocab, an instance of :class:`~texar.data.Vocab`.
        """
        return self._src_vocab

    @property
    def target_vocab(self):
        r"""The target vocab, an instance of :class:`~texar.data.Vocab`.
        """
        return self._tgt_vocab

    @property
    def source_text_name(self):
        r"""The name for source text"""
        name = "{}_text".format(self.hparams.source_dataset["data_name"])
        return name

    @property
    def source_text_id_name(self):
        r"""The name for source text id"""
        name = "{}_text_ids".format(self.hparams.source_dataset["data_name"])
        return name

    @property
    def source_length_name(self):
        r"""The name for source length"""
        name = "{}_length".format(self.hparams.source_dataset["data_name"])
        return name

    @property
    def target_text_name(self):
        r"""The name for target text"""
        name = "{}_text".format(self.hparams.target_dataset["data_name"])
        return name

    @property
    def target_text_id_name(self):
        r"""The name for target text id"""
        name = "{}_text_ids".format(self.hparams.target_dataset["data_name"])
        return name

    @property
    def target_length_name(self):
        r"""The name for target length"""
        name = "{}_length".format(self.hparams.target_dataset["data_name"])
        return name

    @property
    def embedding_init_value(self):
        r"""A pair of `Tensor` containing the embedding values of source and
        target data loaded from file.
        """
        src_emb = self.hparams.source_dataset["embedding_init"]
        tgt_emb = self.hparams.target_dataset["embedding_init"]
        return src_emb, tgt_emb
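
Finally, a hedged configuration sketch for PairedTextData with vocabulary and processing sharing enabled, mirroring the class docstring example; the file names are hypothetical.

hparams = {
    'source_dataset': {'files': 'train.src', 'vocab_file': 'vocab.txt'},
    'target_dataset': {'files': 'train.tgt', 'vocab_share': True,
                       'processing_share': True},
    'batch_size': 16,
}
data = PairedTextData(hparams)
for batch in DataIterator(data):
    # Field names follow source_text_name, target_text_ids, etc.
    print(batch['source_text_ids'].shape, batch['target_text_ids'].shape)
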