示例#1
0
 def make_vocab(hparams):
     """Creates the :class:`texar.tf.data.Vocab` described by :attr:`hparams`.

     Args:
         hparams: Dataset hyperparameters, indexed here by the keys
             `"vocab_file"`, `"bos_token"` and `"eos_token"`.

     Returns:
         A :class:`texar.tf.data.Vocab` built from `hparams["vocab_file"]`.
     """
     # Each configured token goes through `utils.default_str` together with
     # the matching `SpecialTokens` constant (presumably the fallback used
     # when the configured value is empty -- confirm in `utils`).
     special_tokens = {
         "bos_token": utils.default_str(
             hparams["bos_token"], SpecialTokens.BOS),
         "eos_token": utils.default_str(
             hparams["eos_token"], SpecialTokens.EOS),
     }
     return Vocab(hparams["vocab_file"], **special_tokens)
示例#2
0
    def make_vocab(hparams):
        """Builds one vocab entry per dataset described in :attr:`hparams`.

        Args:
            hparams (list): A list of dataset hyperparameters. A value that
                is neither a list nor a tuple is wrapped into a one-element
                list.

        Returns:
            A list with one entry per dataset: a
            :class:`texar.tf.data.Vocab` instance for text datasets and
            `None` for every dataset whose `"data_type"` is not text data.
            Some instances may be the same objects if they are set to be
            shared and resolve to the same BOS/EOS tokens.

        Raises:
            ValueError: If a dataset shares its vocab with a dataset that
                does not have a vocab.
        """
        if not isinstance(hparams, (list, tuple)):
            hparams = [hparams]

        vocabs = []
        for idx, dataset_hparams in enumerate(hparams):
            if not _is_text_data(dataset_hparams["data_type"]):
                vocabs.append(None)
                continue

            # The special tokens are read from the dataset this one shares
            # its processing with, or from the dataset itself otherwise.
            proc_idx = dataset_hparams["processing_share_with"]
            if proc_idx is None:
                token_source = dataset_hparams
            else:
                token_source = hparams[proc_idx]
            raw_bos = token_source["bos_token"]
            raw_eos = token_source["eos_token"]
            bos = utils.default_str(raw_bos, SpecialTokens.BOS)
            eos = utils.default_str(raw_eos, SpecialTokens.EOS)

            share_idx = dataset_hparams["vocab_share_with"]
            if share_idx is None:
                # Not shared: load this dataset's own vocab file.
                vocabs.append(Vocab(dataset_hparams["vocab_file"],
                                    bos_token=bos,
                                    eos_token=eos))
                continue

            # Only vocabs of earlier datasets exist at this point
            # (`_raise_sharing_error` is expected to raise -- see its
            # definition).
            if share_idx >= idx:
                MultiAlignedData._raise_sharing_error(
                    idx, share_idx, "vocab_share_with")
            shared_vocab = vocabs[share_idx]
            if not shared_vocab:
                raise ValueError("Cannot share vocab with dataset %d which "
                                 "does not have a vocab." % share_idx)

            same_tokens = (bos == shared_vocab.bos_token
                           and eos == shared_vocab.eos_token)
            if same_tokens:
                # Identical configuration: reuse the very same object.
                vocabs.append(shared_vocab)
            else:
                # Same vocab file, but different special tokens.
                vocabs.append(Vocab(hparams[share_idx]["vocab_file"],
                                    bos_token=bos,
                                    eos_token=eos))

        return vocabs
示例#3
0
    def make_vocab(src_hparams, tgt_hparams):
        """Creates the source vocab and the target vocab.

        Args:
            src_hparams (dict or HParams): Hyperparameters of source dataset.
            tgt_hparams (dict or HParams): Hyperparameters of target dataset.

        Returns:
            A `(src_vocab, tgt_vocab)` pair of :class:`texar.tf.data.Vocab`
            instances. Both entries are the same object when the target
            shares the source vocab and resolves to the same BOS/EOS tokens.
        """
        src_vocab = MonoTextData.make_vocab(src_hparams)

        # With shared processing the target takes its special tokens from
        # the source hyperparameters.
        token_source = (
            src_hparams if tgt_hparams["processing_share"] else tgt_hparams)
        raw_bos = token_source["bos_token"]
        raw_eos = token_source["eos_token"]
        bos = utils.default_str(raw_bos, SpecialTokens.BOS)
        eos = utils.default_str(raw_eos, SpecialTokens.EOS)

        if not tgt_hparams["vocab_share"]:
            # Independent target vocab file.
            return src_vocab, Vocab(tgt_hparams["vocab_file"],
                                    bos_token=bos,
                                    eos_token=eos)

        if bos == src_vocab.bos_token and eos == src_vocab.eos_token:
            # Identical configuration: reuse the source vocab object.
            return src_vocab, src_vocab

        # Shared vocab file, but the special tokens differ from the source.
        return src_vocab, Vocab(src_hparams["vocab_file"],
                                bos_token=bos,
                                eos_token=eos)