Code example #1
    def __init__(self,
                 hparams,
                 device: Optional[torch.device] = None,
                 vocab: Optional[Vocab] = None,
                 embedding: Optional[Embedding] = None,
                 data_source: Optional[DataSource] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        if self._hparams.dataset.variable_utterance:
            raise NotImplementedError

        # Create vocabulary
        self._bos_token = self._hparams.dataset.bos_token
        self._eos_token = self._hparams.dataset.eos_token
        self._other_transforms = self._hparams.dataset.other_transformations
        bos = utils.default_str(self._bos_token, SpecialTokens.BOS)
        eos = utils.default_str(self._eos_token, SpecialTokens.EOS)
        if vocab is None:
            self._vocab = Vocab(self._hparams.dataset.vocab_file,
                                bos_token=bos,
                                eos_token=eos)
        else:
            self._vocab = vocab

        # Create embedding
        if embedding is None:
            self._embedding = self.make_embedding(
                self._hparams.dataset.embedding_init,
                self._vocab.token_to_id_map_py)
        else:
            self._embedding = embedding

        self._delimiter = self._hparams.dataset.delimiter
        self._max_seq_length = self._hparams.dataset.max_seq_length
        self._length_filter_mode = _LengthFilterMode(
            self._hparams.dataset.length_filter_mode)
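        # The padded length reserves one extra slot for each non-empty
        # BOS/EOS marker added during processing.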
        self._pad_length = self._max_seq_length
        if self._pad_length is not None:
            self._pad_length += sum(
                int(x != '') for x in [self._bos_token, self._eos_token])

        if data_source is None:
            if (self._length_filter_mode is _LengthFilterMode.DISCARD
                    and self._max_seq_length is not None):
                data_source = TextLineDataSource(
                    self._hparams.dataset.files,
                    compression_type=self._hparams.dataset.compression_type,
                    delimiter=self._delimiter,
                    max_length=self._max_seq_length)
            else:
                data_source = TextLineDataSource(
                    self._hparams.dataset.files,
                    compression_type=self._hparams.dataset.compression_type)

        super().__init__(data_source, hparams, device=device)
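
The constructor above is driven entirely by the nested hparams structure. Below is a minimal usage sketch, assuming the snippet is (or behaves like) texar-pytorch's MonoTextData; the file paths and values are hypothetical, and unspecified keys fall back to default_hparams().

    from texar.torch.data import MonoTextData

    # Hypothetical configuration; all paths are made up.
    hparams = {
        "batch_size": 64,
        "dataset": {
            "files": "data/train.txt",        # one example per line
            "vocab_file": "data/vocab.txt",   # one token per line
            "max_seq_length": 50,
            "length_filter_mode": "discard",  # drop over-long examples
        },
    }
    data = MonoTextData(hparams)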
Code example #2
    @staticmethod
    def make_vocab(hparams: List[HParams]) -> List[Optional[Vocab]]:
        r"""Makes a list of vocabs based on the hyperparameters.

        Args:
            hparams (list): A list of dataset hyperparameters.

        Returns:
            A list of :class:`texar.torch.data.Vocab` instances, with ``None``
            for non-text datasets. Some entries may be the same object if the
            corresponding datasets are configured to share a vocabulary and
            use the same special tokens.
        """
        vocabs: List[Optional[Vocab]] = []
        for i, hparams_i in enumerate(hparams):
            if not _is_text_data(hparams_i.data_type):
                vocabs.append(None)
                continue

            proc_share = hparams_i.processing_share_with
            if proc_share is not None:
                bos_token = hparams[proc_share].bos_token
                eos_token = hparams[proc_share].eos_token
            else:
                bos_token = hparams_i.bos_token
                eos_token = hparams_i.eos_token
            bos_token = utils.default_str(bos_token, SpecialTokens.BOS)
            eos_token = utils.default_str(eos_token, SpecialTokens.EOS)

            vocab_share = hparams_i.vocab_share_with
            if vocab_share is not None:
                if vocab_share >= i:
                    MultiAlignedData._raise_sharing_error(
                        i, vocab_share, "vocab_share_with")
                if vocabs[vocab_share] is None:
                    raise ValueError(
                        f"Cannot share vocab with dataset {vocab_share} which "
                        "does not have a vocab.")
                if (bos_token == vocabs[vocab_share].bos_token
                        and eos_token == vocabs[vocab_share].eos_token):
                    vocab = vocabs[vocab_share]
                else:
                    vocab = Vocab(hparams[vocab_share].vocab_file,
                                  bos_token=bos_token,
                                  eos_token=eos_token)
            else:
                vocab = Vocab(hparams_i.vocab_file,
                              bos_token=bos_token,
                              eos_token=eos_token)
            vocabs.append(vocab)

        return vocabs
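
A sketch of how this helper might be called, assuming it lives on texar-pytorch's MultiAlignedData. The loop only reads attributes from each entry, so a SimpleNamespace stands in for HParams here; all field values are hypothetical.

    from types import SimpleNamespace
    from texar.torch.data import MultiAlignedData

    base = dict(data_type="text", bos_token=None, eos_token=None,
                vocab_share_with=None, processing_share_with=None)
    ds0 = SimpleNamespace(**base, vocab_file="data/vocab.txt")
    ds1 = SimpleNamespace(**base, vocab_file=None)
    ds1.vocab_share_with = 0  # reuse dataset 0's vocab object

    vocabs = MultiAlignedData.make_vocab([ds0, ds1])
    assert vocabs[0] is vocabs[1]  # shared: BOS/EOS tokens match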
Code example #3
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())

        src_hparams = self.hparams.source_dataset
        tgt_hparams = self.hparams.target_dataset

        # create vocabulary
        self._src_bos_token = src_hparams["bos_token"]
        self._src_eos_token = src_hparams["eos_token"]
        self._src_transforms = src_hparams["other_transformations"]
        self._src_vocab = Vocab(src_hparams.vocab_file,
                                bos_token=src_hparams.bos_token,
                                eos_token=src_hparams.eos_token)

        if tgt_hparams["processing_share"]:
            self._tgt_bos_token = src_hparams["bos_token"]
            self._tgt_eos_token = src_hparams["eos_token"]
        else:
            self._tgt_bos_token = tgt_hparams["bos_token"]
            self._tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                          SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                          SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == self._src_vocab.bos_token and \
                    tgt_eos_token == self._src_vocab.eos_token:
                self._tgt_vocab = self._src_vocab
            else:
                self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                        bos_token=tgt_bos_token,
                                        eos_token=tgt_eos_token)
        else:
            self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)

        # create embeddings
        self._src_embedding = MonoTextData.make_embedding(
            src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)

        if self._hparams.target_dataset.embedding_init_share:
            self._tgt_embedding = self._src_embedding
        else:
            tgt_emb_file = tgt_hparams.embedding_init["file"]
            self._tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                self._tgt_embedding = MonoTextData.make_embedding(
                    tgt_hparams.embedding_init,
                    self._tgt_vocab.token_to_id_map_py)

        # create data source
        self._src_delimiter = src_hparams.delimiter
        self._src_max_seq_length = src_hparams.max_seq_length
        self._src_length_filter_mode = _LengthFilterMode(
            src_hparams.length_filter_mode)
        self._src_pad_length = self._src_max_seq_length
        if self._src_pad_length is not None:
            self._src_pad_length += sum(
                int(x is not None and x != '')
                for x in [src_hparams.bos_token, src_hparams.eos_token])

        src_data_source = TextLineDataSource(
            src_hparams.files, compression_type=src_hparams.compression_type)

        self._tgt_transforms = tgt_hparams["other_transformations"]
        self._tgt_delimiter = tgt_hparams.delimiter
        self._tgt_max_seq_length = tgt_hparams.max_seq_length
        self._tgt_length_filter_mode = _LengthFilterMode(
            tgt_hparams.length_filter_mode)
        self._tgt_pad_length = self._tgt_max_seq_length
        if self._tgt_pad_length is not None:
            self._tgt_pad_length += sum(
                int(x is not None and x != '')
                for x in [tgt_hparams.bos_token, tgt_hparams.eos_token])

        tgt_data_source = TextLineDataSource(
            tgt_hparams.files, compression_type=tgt_hparams.compression_type)

        data_source: DataSource[Tuple[List[str], List[str]]]
        data_source = ZipDataSource(  # type: ignore
            src_data_source, tgt_data_source)
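        # When either side uses DISCARD length filtering, drop whole pairs
        # in which either sequence exceeds its limit.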
        if ((self._src_length_filter_mode is _LengthFilterMode.DISCARD
             and self._src_max_seq_length is not None)
                or (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD
                    and self._tgt_max_seq_length is not None)):
            max_src_length = self._src_max_seq_length or math.inf
            max_tgt_length = self._tgt_max_seq_length or math.inf

            def filter_fn(raw_example):
                return (len(raw_example[0]) <= max_src_length
                        and len(raw_example[1]) <= max_tgt_length)

            data_source = FilterDataSource(data_source, filter_fn)

        super().__init__(data_source, hparams, device=device)
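
As in the first example, everything comes from hparams. Below is a hypothetical configuration for this paired constructor, assuming it follows texar-pytorch's PairedTextData and its default_hparams(); all paths are invented.

    from texar.torch.data import PairedTextData

    hparams = {
        "batch_size": 32,
        "source_dataset": {
            "files": "data/train.src",
            "vocab_file": "data/vocab.src",
            "max_seq_length": 40,
            "length_filter_mode": "discard",
        },
        "target_dataset": {
            "files": "data/train.tgt",
            "vocab_share": True,           # reuse the source Vocab
            "embedding_init_share": True,  # and the source embedding
            "max_seq_length": 40,
            "length_filter_mode": "discard",
        },
    }
    data = PairedTextData(hparams)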