Example No. 1
    def make_vocab(hparams):
        """Reads vocab file and returns an instance of
        :class:`texar.data.Vocab`.
        """
        bos_token = utils.default_str(hparams["bos_token"], SpecialTokens.BOS)
        eos_token = utils.default_str(hparams["eos_token"], SpecialTokens.EOS)
        vocab = Vocab(hparams["vocab_file"],
                      bos_token=bos_token,
                      eos_token=eos_token)
        return vocab
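A minimal usage sketch (not part of the original example): assuming this is the static MonoTextData.make_vocab helper, as Example No. 5 below suggests, and using a placeholder vocab path. Empty token strings are assumed to fall back to SpecialTokens.BOS / SpecialTokens.EOS via utils.default_str.

    # Hedged sketch: "data/vocab.txt" is a placeholder path.
    hparams = {
        "vocab_file": "data/vocab.txt",
        "bos_token": "",   # assumed to default to SpecialTokens.BOS
        "eos_token": "",   # assumed to default to SpecialTokens.EOS
    }
    vocab = MonoTextData.make_vocab(hparams)
    print(vocab.bos_token, vocab.eos_token)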
Example No. 2
    def make_vocab(hparams):
        r"""Makes a list of vocabs based on the hyperparameters.

        Args:
            hparams (list): A list of dataset hyperparameters.

        Returns:
            A list of :class:`texar.data.Vocab` instances. Some instances
            may be the same objects if they are set to be shared and have
            the same other configurations.
        """
        if not isinstance(hparams, (list, tuple)):
            hparams = [hparams]

        vocabs = []
        for i, hparams_i in enumerate(hparams):
            if not _is_text_data(hparams_i["data_type"]):
                vocabs.append(None)
                continue

            proc_share = hparams_i["processing_share_with"]
            if proc_share is not None:
                bos_token = hparams[proc_share]["bos_token"]
                eos_token = hparams[proc_share]["eos_token"]
            else:
                bos_token = hparams_i["bos_token"]
                eos_token = hparams_i["eos_token"]
            bos_token = utils.default_str(
                bos_token, SpecialTokens.BOS)
            eos_token = utils.default_str(
                eos_token, SpecialTokens.EOS)

            vocab_share = hparams_i["vocab_share_with"]
            if vocab_share is not None:
                if vocab_share >= i:
                    MultiAlignedData._raise_sharing_error(
                        i, vocab_share, "vocab_share_with")
                if vocabs[vocab_share] is None:
                    raise ValueError("Cannot share vocab with dataset %d which "
                                     "does not have a vocab." % vocab_share)
                if bos_token == vocabs[vocab_share].bos_token and \
                        eos_token == vocabs[vocab_share].eos_token:
                    vocab = vocabs[vocab_share]
                else:
                    vocab = Vocab(hparams[vocab_share]["vocab_file"],
                                  bos_token=bos_token,
                                  eos_token=eos_token)
            else:
                vocab = Vocab(hparams_i["vocab_file"],
                              bos_token=bos_token,
                              eos_token=eos_token)
            vocabs.append(vocab)

        return vocabs
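A hedged sketch (not from the original source): assuming this is MultiAlignedData.make_vocab, as the reference to MultiAlignedData._raise_sharing_error suggests, two text datasets can share one vocab object. The paths and hparams values below are placeholders.

    # Hedged sketch: the second dataset reuses the first dataset's vocab and
    # processing tokens via the *_share_with indices.
    hparams_list = [
        {"data_type": "text", "vocab_file": "data/vocab.txt",
         "bos_token": "", "eos_token": "",
         "vocab_share_with": None, "processing_share_with": None},
        {"data_type": "text", "vocab_file": "",
         "bos_token": "", "eos_token": "",
         "vocab_share_with": 0, "processing_share_with": 0},
    ]
    vocabs = MultiAlignedData.make_vocab(hparams_list)
    assert vocabs[1] is vocabs[0]  # same object: shared vocab, matching tokens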
Example No. 3
    def _construct(cls,
                   hparams,
                   device: Optional[torch.device] = None,
                   vocab: Optional[Vocab] = None,
                   embedding: Optional[Embedding] = None):
        mono_text_data = cls.__new__(cls)
        mono_text_data._hparams = HParams(hparams,
                                          mono_text_data.default_hparams())
        if mono_text_data._hparams.dataset.variable_utterance:
            raise NotImplementedError

        dataset = mono_text_data._hparams.dataset
        mono_text_data._other_transforms = dataset.other_transformations

        # Create vocabulary
        if vocab is not None:
            mono_text_data._vocab = vocab
            mono_text_data._bos_token = vocab.bos_token
            mono_text_data._eos_token = vocab.eos_token
        else:
            mono_text_data._bos_token = dataset.bos_token
            mono_text_data._eos_token = dataset.eos_token
            bos = utils.default_str(mono_text_data._bos_token,
                                    SpecialTokens.BOS)
            eos = utils.default_str(mono_text_data._eos_token,
                                    SpecialTokens.EOS)
            mono_text_data._vocab = Vocab(dataset.vocab_file,
                                          bos_token=bos,
                                          eos_token=eos)

        # Create embedding
        if embedding is not None:
            mono_text_data._embedding = embedding
        else:
            mono_text_data._embedding = mono_text_data.make_embedding(
                dataset.embedding_init,
                mono_text_data._vocab.token_to_id_map_py)

        mono_text_data._delimiter = dataset.delimiter
        mono_text_data._max_seq_length = dataset.max_seq_length
        mono_text_data._length_filter_mode = _LengthFilterMode(
            mono_text_data._hparams.dataset.length_filter_mode)
        mono_text_data._pad_length = mono_text_data._max_seq_length
        if mono_text_data._pad_length is not None:
            mono_text_data._pad_length += sum(
                int(x != '') for x in
                [mono_text_data._bos_token, mono_text_data._eos_token])

        data_source: SequenceDataSource[str] = SequenceDataSource([])
        super(MonoTextData, mono_text_data).__init__(source=data_source,
                                                     hparams=hparams,
                                                     device=device)

        return mono_text_data
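A minimal sketch (an assumption, since _construct is an internal classmethod): it rebuilds a MonoTextData instance around an already-loaded Vocab, so the vocab file is not read again. Paths are placeholders.

    # Hedged sketch: passing an existing Vocab takes the "vocab is not None"
    # branch of _construct above.
    hparams = {"dataset": {"files": "data/train.txt",
                           "vocab_file": "data/vocab.txt"}}
    shared_vocab = Vocab("data/vocab.txt")
    data = MonoTextData._construct(hparams, device=None, vocab=shared_vocab)
    assert data._vocab is shared_vocab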
Example No. 4
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        if self._hparams.dataset.variable_utterance:
            raise NotImplementedError

        # Create vocabulary
        self._bos_token = self._hparams.dataset.bos_token
        self._eos_token = self._hparams.dataset.eos_token
        self._other_transforms = self._hparams.dataset.other_transformations
        bos = utils.default_str(self._bos_token, SpecialTokens.BOS)
        eos = utils.default_str(self._eos_token, SpecialTokens.EOS)
        self._vocab = Vocab(self._hparams.dataset.vocab_file,
                            bos_token=bos,
                            eos_token=eos)

        # Create embedding
        self._embedding = self.make_embedding(
            self._hparams.dataset.embedding_init,
            self._vocab.token_to_id_map_py)

        self._delimiter = self._hparams.dataset.delimiter
        self._max_seq_length = self._hparams.dataset.max_seq_length
        self._length_filter_mode = _LengthFilterMode(
            self._hparams.dataset.length_filter_mode)
        self._pad_length = self._max_seq_length
        if self._pad_length is not None:
            self._pad_length += sum(
                int(x != '') for x in [self._bos_token, self._eos_token])

        if (self._length_filter_mode is _LengthFilterMode.DISCARD
                and self._max_seq_length is not None):
            data_source = TextLineDataSource(
                self._hparams.dataset.files,
                compression_type=self._hparams.dataset.compression_type,
                delimiter=self._delimiter,
                max_length=self._max_seq_length)
        else:
            data_source = TextLineDataSource(
                self._hparams.dataset.files,
                compression_type=self._hparams.dataset.compression_type)

        super().__init__(data_source, hparams, device=device)
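A hedged usage sketch for this constructor (assumed to be MonoTextData.__init__, mirroring the _construct path in Example No. 3): placeholder paths, with the remaining options filled in from default_hparams.

    # Hedged sketch: a dataset that discards examples longer than 50 tokens.
    hparams = {
        "dataset": {
            "files": "data/train.txt",
            "vocab_file": "data/vocab.txt",
            "max_seq_length": 50,
            "length_filter_mode": "discard",
        },
        "batch_size": 32,
    }
    data = MonoTextData(hparams, device=None)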
Example No. 5
    def make_vocab(src_hparams, tgt_hparams):
        """Reads vocab files and returns source vocab and target vocab.

        Args:
            src_hparams (dict or HParams): Hyperparameters of source dataset.
            tgt_hparams (dict or HParams): Hyperparameters of target dataset.

        Returns:
            A pair of :class:`texar.data.Vocab` instances. The two instances
            may be the same objects if source and target vocabs are shared
            and have the same other configs.
        """
        src_vocab = MonoTextData.make_vocab(src_hparams)

        if tgt_hparams["processing_share"]:
            tgt_bos_token = src_hparams["bos_token"]
            tgt_eos_token = src_hparams["eos_token"]
        else:
            tgt_bos_token = tgt_hparams["bos_token"]
            tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(tgt_bos_token, SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(tgt_eos_token, SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == src_vocab.bos_token and \
                    tgt_eos_token == src_vocab.eos_token:
                tgt_vocab = src_vocab
            else:
                tgt_vocab = Vocab(src_hparams["vocab_file"],
                                  bos_token=tgt_bos_token,
                                  eos_token=tgt_eos_token)
        else:
            tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                              bos_token=tgt_bos_token,
                              eos_token=tgt_eos_token)

        return src_vocab, tgt_vocab
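A hedged sketch (not from the original source): the owning class is assumed to be PairedTextData, since the function delegates the source side to MonoTextData.make_vocab. With full sharing the same Vocab object is returned for both sides; paths are placeholders.

    # Hedged sketch: the target shares processing tokens and the vocab file
    # with the source, so both names point to one Vocab object.
    src_hparams = {"vocab_file": "data/vocab.txt",
                   "bos_token": "", "eos_token": ""}
    tgt_hparams = {"vocab_file": "", "bos_token": "", "eos_token": "",
                   "processing_share": True, "vocab_share": True}
    src_vocab, tgt_vocab = PairedTextData.make_vocab(src_hparams, tgt_hparams)
    assert tgt_vocab is src_vocab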
Example No. 6
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())

        src_hparams = self.hparams.source_dataset
        tgt_hparams = self.hparams.target_dataset

        # create vocabulary
        self._src_bos_token = src_hparams["bos_token"]
        self._src_eos_token = src_hparams["eos_token"]
        self._src_transforms = src_hparams["other_transformations"]
        self._src_vocab = Vocab(src_hparams.vocab_file,
                                bos_token=src_hparams.bos_token,
                                eos_token=src_hparams.eos_token)

        if tgt_hparams["processing_share"]:
            self._tgt_bos_token = src_hparams["bos_token"]
            self._tgt_eos_token = src_hparams["eos_token"]
        else:
            self._tgt_bos_token = tgt_hparams["bos_token"]
            self._tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                          SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                          SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == self._src_vocab.bos_token and \
                    tgt_eos_token == self._src_vocab.eos_token:
                self._tgt_vocab = self._src_vocab
            else:
                self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                        bos_token=tgt_bos_token,
                                        eos_token=tgt_eos_token)
        else:
            self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)

        # create embeddings
        self._src_embedding = MonoTextData.make_embedding(
            src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)

        if self._hparams.target_dataset.embedding_init_share:
            self._tgt_embedding = self._src_embedding
        else:
            tgt_emb_file = tgt_hparams.embedding_init["file"]
            self._tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                self._tgt_embedding = MonoTextData.make_embedding(
                    tgt_hparams.embedding_init,
                    self._tgt_vocab.token_to_id_map_py)

        # create data source
        self._src_delimiter = src_hparams.delimiter
        self._src_max_seq_length = src_hparams.max_seq_length
        self._src_length_filter_mode = _LengthFilterMode(
            src_hparams.length_filter_mode)
        self._src_pad_length = self._src_max_seq_length
        if self._src_pad_length is not None:
            self._src_pad_length += sum(int(x is not None and x != '')
                                        for x in [src_hparams.bos_token,
                                                  src_hparams.eos_token])

        src_data_source = TextLineDataSource(
            src_hparams.files,
            compression_type=src_hparams.compression_type)

        self._tgt_transforms = tgt_hparams["other_transformations"]
        self._tgt_delimiter = tgt_hparams.delimiter
        self._tgt_max_seq_length = tgt_hparams.max_seq_length
        self._tgt_length_filter_mode = _LengthFilterMode(
            tgt_hparams.length_filter_mode)
        self._tgt_pad_length = self._tgt_max_seq_length
        if self._tgt_pad_length is not None:
            self._tgt_pad_length += sum(int(x is not None and x != '')
                                        for x in [tgt_hparams.bos_token,
                                                  tgt_hparams.eos_token])

        tgt_data_source = TextLineDataSource(
            tgt_hparams.files,
            compression_type=tgt_hparams.compression_type)

        data_source: DataSource[Tuple[str, str]]
        data_source = ZipDataSource(  # type: ignore
            src_data_source, tgt_data_source)
        if (self._src_length_filter_mode is _LengthFilterMode.DISCARD and
            self._src_max_seq_length is not None) or \
                (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD and
                 self._tgt_max_seq_length is not None):
            max_source_length = self._src_max_seq_length if \
                self._src_max_seq_length is not None else np.inf
            max_tgt_length = self._tgt_max_seq_length if \
                self._tgt_max_seq_length is not None else np.inf

            def filter_fn(raw_example):
                return len(raw_example[0].split(self._src_delimiter)) \
                       <= max_source_length and \
                       len(raw_example[1].split(self._tgt_delimiter)) \
                       <= max_tgt_length

            data_source = FilterDataSource(data_source, filter_fn)

        super().__init__(data_source, hparams, device=device)
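A hedged usage sketch for this paired constructor (the class name PairedTextData is an assumption; paths are placeholders and the remaining options come from the defaults).

    # Hedged sketch: parallel source/target files; the target reuses the
    # source vocab via "vocab_share".
    hparams = {
        "source_dataset": {"files": "data/train.src",
                           "vocab_file": "data/vocab.txt"},
        "target_dataset": {"files": "data/train.tgt", "vocab_share": True},
        "batch_size": 32,
    }
    data = PairedTextData(hparams, device=None)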