def make_vocab(hparams):
    """Reads the vocab file named in ``hparams`` and returns a
    :class:`texar.data.Vocab` built from it.
    """
    # Fall back to the library-wide special tokens when BOS/EOS are
    # left unspecified in the hyperparameters.
    bos = utils.default_str(hparams["bos_token"], SpecialTokens.BOS)
    eos = utils.default_str(hparams["eos_token"], SpecialTokens.EOS)
    return Vocab(hparams["vocab_file"], bos_token=bos, eos_token=eos)
def make_vocab(hparams):
    r"""Makes a list of vocabs based on the hyperparameters.

    Args:
        hparams (list): A list of dataset hyperparameters.

    Returns:
        A list of :class:`texar.data.Vocab` instances. Some instances
        may be the same objects if they are set to be shared and have
        the same other configurations.
    """
    if not isinstance(hparams, (list, tuple)):
        hparams = [hparams]

    vocabs = []
    for idx, ds_hparams in enumerate(hparams):
        if not _is_text_data(ds_hparams["data_type"]):
            # Non-text datasets carry no vocabulary.
            vocabs.append(None)
            continue

        # Resolve BOS/EOS, possibly borrowing them from the dataset this
        # one shares processing with.
        proc_share = ds_hparams["processing_share_with"]
        token_src = hparams[proc_share] if proc_share is not None \
            else ds_hparams
        bos_token = utils.default_str(
            token_src["bos_token"], SpecialTokens.BOS)
        eos_token = utils.default_str(
            token_src["eos_token"], SpecialTokens.EOS)

        vocab_share = ds_hparams["vocab_share_with"]
        if vocab_share is None:
            vocab = Vocab(ds_hparams["vocab_file"],
                          bos_token=bos_token, eos_token=eos_token)
        else:
            # Sharing is only valid with an earlier dataset.
            if vocab_share >= idx:
                MultiAlignedData._raise_sharing_error(
                    idx, vocab_share, "vocab_share_with")
            shared = vocabs[vocab_share]
            if shared is None:
                raise ValueError("Cannot share vocab with dataset %d which "
                                 "does not have a vocab." % vocab_share)
            if (bos_token == shared.bos_token and
                    eos_token == shared.eos_token):
                # Identical configuration: reuse the shared Vocab object.
                vocab = shared
            else:
                # Same vocab file but different special tokens: rebuild.
                vocab = Vocab(hparams[vocab_share]["vocab_file"],
                              bos_token=bos_token, eos_token=eos_token)
        vocabs.append(vocab)

    return vocabs
def _construct(cls, hparams, device: Optional[torch.device] = None,
               vocab: Optional[Vocab] = None,
               embedding: Optional[Vocab] = None):
    r"""Constructs a dataset instance from ``hparams`` without reading any
    data files, optionally reusing a pre-built ``vocab`` and ``embedding``.

    NOTE(review): ``embedding`` is annotated ``Optional[Vocab]`` although
    it is stored as an embedding object; annotation kept unchanged to
    preserve the published signature — confirm intended type.
    """
    data = cls.__new__(cls)
    data._hparams = HParams(hparams, data.default_hparams())
    if data._hparams.dataset.variable_utterance:
        raise NotImplementedError

    dataset = data._hparams.dataset
    data._other_transforms = dataset.other_transformations

    # Vocabulary: reuse the provided one, else build from the vocab file.
    if vocab is None:
        data._bos_token = dataset.bos_token
        data._eos_token = dataset.eos_token
        bos = utils.default_str(data._bos_token, SpecialTokens.BOS)
        eos = utils.default_str(data._eos_token, SpecialTokens.EOS)
        data._vocab = Vocab(dataset.vocab_file,
                            bos_token=bos, eos_token=eos)
    else:
        data._vocab = vocab
        data._bos_token = vocab.bos_token
        data._eos_token = vocab.eos_token

    # Embedding: reuse the provided one, else initialize per hparams.
    if embedding is None:
        data._embedding = data.make_embedding(
            dataset.embedding_init, data._vocab.token_to_id_map_py)
    else:
        data._embedding = embedding

    data._delimiter = dataset.delimiter
    data._max_seq_length = dataset.max_seq_length
    data._length_filter_mode = _LengthFilterMode(
        data._hparams.dataset.length_filter_mode)
    data._pad_length = data._max_seq_length
    if data._pad_length is not None:
        # Reserve room for whichever of BOS/EOS is a non-empty token.
        data._pad_length += sum(
            int(tok != '') for tok in [data._bos_token, data._eos_token])

    # Empty placeholder source; actual data is supplied by the caller.
    data_source: SequenceDataSource[str] = SequenceDataSource([])
    super(MonoTextData, data).__init__(
        source=data_source, hparams=hparams, device=device)

    return data
def __init__(self, hparams, device: Optional[torch.device] = None):
    r"""Initializes the dataset: builds the vocabulary and embedding from
    ``hparams``, then sets up a text-line source over the data files.
    """
    self._hparams = HParams(hparams, self.default_hparams())
    if self._hparams.dataset.variable_utterance:
        raise NotImplementedError

    ds = self._hparams.dataset

    # Vocabulary, with library defaults for unspecified BOS/EOS.
    self._bos_token = ds.bos_token
    self._eos_token = ds.eos_token
    self._other_transforms = ds.other_transformations
    self._vocab = Vocab(
        ds.vocab_file,
        bos_token=utils.default_str(self._bos_token, SpecialTokens.BOS),
        eos_token=utils.default_str(self._eos_token, SpecialTokens.EOS))

    # Embedding initialized per hyperparameters.
    self._embedding = self.make_embedding(
        ds.embedding_init, self._vocab.token_to_id_map_py)

    self._delimiter = ds.delimiter
    self._max_seq_length = ds.max_seq_length
    self._length_filter_mode = _LengthFilterMode(ds.length_filter_mode)
    self._pad_length = self._max_seq_length
    if self._pad_length is not None:
        # Reserve room for whichever of BOS/EOS is a non-empty token.
        self._pad_length += sum(
            int(tok != '') for tok in [self._bos_token, self._eos_token])

    # When discarding over-long examples, push the length limit down
    # into the data source so they are never materialized.
    if (self._length_filter_mode is _LengthFilterMode.DISCARD
            and self._max_seq_length is not None):
        data_source = TextLineDataSource(
            ds.files,
            compression_type=ds.compression_type,
            delimiter=self._delimiter,
            max_length=self._max_seq_length)
    else:
        data_source = TextLineDataSource(
            ds.files, compression_type=ds.compression_type)

    super().__init__(data_source, hparams, device=device)
def make_vocab(src_hparams, tgt_hparams):
    """Reads vocab files and returns source vocab and target vocab.

    Args:
        src_hparams (dict or HParams): Hyperparameters of source dataset.
        tgt_hparams (dict or HParams): Hyperparameters of target dataset.

    Returns:
        A pair of :class:`texar.data.Vocab` instances. The two instances
        may be the same objects if source and target vocabs are shared
        and have the same other configs.
    """
    src_vocab = MonoTextData.make_vocab(src_hparams)

    # Target BOS/EOS come from the source config when processing is shared.
    if tgt_hparams["processing_share"]:
        bos, eos = src_hparams["bos_token"], src_hparams["eos_token"]
    else:
        bos, eos = tgt_hparams["bos_token"], tgt_hparams["eos_token"]
    bos = utils.default_str(bos, SpecialTokens.BOS)
    eos = utils.default_str(eos, SpecialTokens.EOS)

    if not tgt_hparams["vocab_share"]:
        tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                          bos_token=bos, eos_token=eos)
    elif bos == src_vocab.bos_token and eos == src_vocab.eos_token:
        # Identical special tokens: the source Vocab object is reused.
        tgt_vocab = src_vocab
    else:
        # Shared vocab file but different special tokens: rebuild from it.
        tgt_vocab = Vocab(src_hparams["vocab_file"],
                          bos_token=bos, eos_token=eos)

    return src_vocab, tgt_vocab
def __init__(self, hparams, device: Optional[torch.device] = None):
    r"""Initializes the paired-text dataset.

    Builds source/target vocabularies and embeddings (honoring the
    ``processing_share`` / ``vocab_share`` / ``embedding_init_share``
    flags), then zips a source and a target text-line data source,
    filtering over-long pairs when the filter mode is ``DISCARD``.

    Args:
        hparams: Dataset hyperparameters (dict or HParams).
        device: Optional torch device for the produced batches.
    """
    self._hparams = HParams(hparams, self.default_hparams())
    src_hparams = self.hparams.source_dataset
    tgt_hparams = self.hparams.target_dataset

    # -- source vocabulary ------------------------------------------------
    self._src_bos_token = src_hparams["bos_token"]
    self._src_eos_token = src_hparams["eos_token"]
    self._src_transforms = src_hparams["other_transformations"]
    # NOTE(review): unlike the target tokens below, the source BOS/EOS are
    # not passed through ``utils.default_str`` — confirm the defaults in
    # ``default_hparams`` already supply concrete token strings.
    self._src_vocab = Vocab(src_hparams.vocab_file,
                            bos_token=src_hparams.bos_token,
                            eos_token=src_hparams.eos_token)

    # -- target vocabulary, possibly shared with the source ---------------
    if tgt_hparams["processing_share"]:
        self._tgt_bos_token = src_hparams["bos_token"]
        self._tgt_eos_token = src_hparams["eos_token"]
    else:
        self._tgt_bos_token = tgt_hparams["bos_token"]
        self._tgt_eos_token = tgt_hparams["eos_token"]
    tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                      SpecialTokens.BOS)
    tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                      SpecialTokens.EOS)
    if tgt_hparams["vocab_share"]:
        if (tgt_bos_token == self._src_vocab.bos_token and
                tgt_eos_token == self._src_vocab.eos_token):
            # Identical special tokens: reuse the source Vocab object.
            self._tgt_vocab = self._src_vocab
        else:
            # Shared vocab file but different special tokens: rebuild.
            self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)
    else:
        self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                bos_token=tgt_bos_token,
                                eos_token=tgt_eos_token)

    # -- embeddings -------------------------------------------------------
    self._src_embedding = MonoTextData.make_embedding(
        src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)
    if self._hparams.target_dataset.embedding_init_share:
        self._tgt_embedding = self._src_embedding
    else:
        tgt_emb_file = tgt_hparams.embedding_init["file"]
        self._tgt_embedding = None
        if tgt_emb_file is not None and tgt_emb_file != "":
            # BUG FIX: the arguments were previously passed in swapped
            # order (token map first); ``make_embedding`` takes the init
            # hyperparameters first, as in the source-side call above.
            self._tgt_embedding = MonoTextData.make_embedding(
                tgt_hparams.embedding_init,
                self._tgt_vocab.token_to_id_map_py)

    # -- data sources -----------------------------------------------------
    self._src_delimiter = src_hparams.delimiter
    self._src_max_seq_length = src_hparams.max_seq_length
    self._src_length_filter_mode = _LengthFilterMode(
        src_hparams.length_filter_mode)
    self._src_pad_length = self._src_max_seq_length
    if self._src_pad_length is not None:
        # Reserve room for whichever of BOS/EOS is a non-empty token.
        self._src_pad_length += sum(
            int(x is not None and x != '')
            for x in [src_hparams.bos_token, src_hparams.eos_token])

    src_data_source = TextLineDataSource(
        src_hparams.files, compression_type=src_hparams.compression_type)

    self._tgt_transforms = tgt_hparams["other_transformations"]
    self._tgt_delimiter = tgt_hparams.delimiter
    self._tgt_max_seq_length = tgt_hparams.max_seq_length
    self._tgt_length_filter_mode = _LengthFilterMode(
        tgt_hparams.length_filter_mode)
    self._tgt_pad_length = self._tgt_max_seq_length
    if self._tgt_pad_length is not None:
        self._tgt_pad_length += sum(
            int(x is not None and x != '')
            for x in [tgt_hparams.bos_token, tgt_hparams.eos_token])

    tgt_data_source = TextLineDataSource(
        tgt_hparams.files, compression_type=tgt_hparams.compression_type)

    data_source: DataSource[Tuple[str, str]]
    data_source = ZipDataSource(  # type: ignore
        src_data_source, tgt_data_source)

    src_discard = (
        self._src_length_filter_mode is _LengthFilterMode.DISCARD
        and self._src_max_seq_length is not None)
    # BUG FIX: the target check previously re-tested
    # ``self._tgt_length_filter_mode is not None`` (always true at this
    # point) instead of the max length, so a no-op filter was installed
    # when DISCARD was set without a target length limit.
    tgt_discard = (
        self._tgt_length_filter_mode is _LengthFilterMode.DISCARD
        and self._tgt_max_seq_length is not None)
    if src_discard or tgt_discard:
        max_src_length = (self._src_max_seq_length
                          if self._src_max_seq_length is not None
                          else np.inf)
        max_tgt_length = (self._tgt_max_seq_length
                          if self._tgt_max_seq_length is not None
                          else np.inf)

        def filter_fn(raw_example):
            # Lengths are measured in delimiter-separated tokens.
            return (len(raw_example[0].split(self._src_delimiter))
                    <= max_src_length and
                    len(raw_example[1].split(self._tgt_delimiter))
                    <= max_tgt_length)

        data_source = FilterDataSource(data_source, filter_fn)

    super().__init__(data_source, hparams, device=device)