def __init__(self, hparams, device: Optional[torch.device] = None,
             vocab: Optional[Vocab] = None,
             embedding: Optional[Embedding] = None,
             data_source: Optional[DataSource] = None):
    self._hparams = HParams(hparams, self.default_hparams())
    if self._hparams.dataset.variable_utterance:
        raise NotImplementedError

    # Create vocabulary
    self._bos_token = self._hparams.dataset.bos_token
    self._eos_token = self._hparams.dataset.eos_token
    self._other_transforms = self._hparams.dataset.other_transformations
    bos = utils.default_str(self._bos_token, SpecialTokens.BOS)
    eos = utils.default_str(self._eos_token, SpecialTokens.EOS)
    if vocab is None:
        self._vocab = Vocab(self._hparams.dataset.vocab_file,
                            bos_token=bos, eos_token=eos)
    else:
        self._vocab = vocab

    # Create embedding: build from hparams only when none was passed in.
    if embedding is None:
        self._embedding = self.make_embedding(
            self._hparams.dataset.embedding_init,
            self._vocab.token_to_id_map_py)
    else:
        self._embedding = embedding

    self._delimiter = self._hparams.dataset.delimiter
    self._max_seq_length = self._hparams.dataset.max_seq_length
    self._length_filter_mode = _LengthFilterMode(
        self._hparams.dataset.length_filter_mode)
    # Padded length includes BOS/EOS if they are non-empty.
    self._pad_length = self._max_seq_length
    if self._pad_length is not None:
        self._pad_length += sum(
            int(x != '') for x in [self._bos_token, self._eos_token])

    # Create data source
    if data_source is None:
        if (self._length_filter_mode is _LengthFilterMode.DISCARD and
                self._max_seq_length is not None):
            data_source = TextLineDataSource(
                self._hparams.dataset.files,
                compression_type=self._hparams.dataset.compression_type,
                delimiter=self._delimiter,
                max_length=self._max_seq_length)
        else:
            data_source = TextLineDataSource(
                self._hparams.dataset.files,
                compression_type=self._hparams.dataset.compression_type)

    super().__init__(data_source, hparams, device=device)
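
# The constructor above is normally driven entirely by hparams. Below is a
# minimal usage sketch, assuming this is texar.torch.data.MonoTextData; the
# file names ("data.txt", "vocab.txt") and batch sizes are placeholders, and
# the batch field names are taken from the library's documented defaults.
import texar.torch as tx

mono_hparams = {
    "dataset": {
        "files": "data.txt",        # one example per line (placeholder path)
        "vocab_file": "vocab.txt",  # one token per line (placeholder path)
        "max_seq_length": 20,
        "length_filter_mode": "discard",  # drop overlong examples, as above
    },
    "batch_size": 32,
}
mono_data = tx.data.MonoTextData(mono_hparams)
mono_iterator = tx.data.DataIterator(mono_data)
for batch in mono_iterator:
    # Each batch carries padded token ids plus lengths and raw text.
    print(batch.text_ids.shape, batch.length.shape)
    break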
def make_vocab(hparams: List[HParams]) -> List[Optional[Vocab]]:
    r"""Makes a list of vocabs based on the hyperparameters.

    Args:
        hparams (list): A list of dataset hyperparameters.

    Returns:
        A list of :class:`texar.torch.data.Vocab` instances. Some instances
        may be the same objects if they are set to be shared and have
        the same other configurations.
    """
    vocabs: List[Optional[Vocab]] = []
    for i, hparams_i in enumerate(hparams):
        # Non-text datasets do not carry a vocabulary.
        if not _is_text_data(hparams_i.data_type):
            vocabs.append(None)
            continue

        proc_share = hparams_i.processing_share_with
        if proc_share is not None:
            bos_token = hparams[proc_share].bos_token
            eos_token = hparams[proc_share].eos_token
        else:
            bos_token = hparams_i.bos_token
            eos_token = hparams_i.eos_token
        bos_token = utils.default_str(bos_token, SpecialTokens.BOS)
        eos_token = utils.default_str(eos_token, SpecialTokens.EOS)

        vocab_share = hparams_i.vocab_share_with
        if vocab_share is not None:
            if vocab_share >= i:
                MultiAlignedData._raise_sharing_error(
                    i, vocab_share, "vocab_share_with")
            if vocabs[vocab_share] is None:
                raise ValueError(
                    f"Cannot share vocab with dataset {vocab_share} which "
                    "does not have a vocab.")
            if (bos_token == vocabs[vocab_share].bos_token and
                    eos_token == vocabs[vocab_share].eos_token):
                # Identical special tokens: reuse the shared Vocab object.
                vocab = vocabs[vocab_share]
            else:
                # Same vocab file, but different BOS/EOS tokens.
                vocab = Vocab(hparams[vocab_share].vocab_file,
                              bos_token=bos_token, eos_token=eos_token)
        else:
            vocab = Vocab(hparams_i.vocab_file,
                          bos_token=bos_token, eos_token=eos_token)
        vocabs.append(vocab)
    return vocabs
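
# make_vocab is normally invoked from MultiAlignedData's constructor with one
# hparams entry per dataset. A hedged sketch of the sharing behavior it
# implements, assuming the standard texar.torch MultiAlignedData hparams
# layout; file paths and data names are placeholders.
import texar.torch as tx

multi_hparams = {
    "datasets": [
        {"files": "src.txt", "vocab_file": "vocab.txt", "data_name": "src"},
        # Dataset 1 reuses dataset 0's vocab via vocab_share_with.
        {"files": "tgt.txt", "vocab_share_with": 0, "data_name": "tgt"},
    ],
    "batch_size": 16,
}
multi_data = tx.data.MultiAlignedData(multi_hparams)
# With matching bos/eos tokens the second dataset reuses the same Vocab object
# (the vocab(name_or_id) accessor below is assumed from the library's API).
print(multi_data.vocab("src") is multi_data.vocab("tgt"))  # expected: True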
def __init__(self, hparams, device: Optional[torch.device] = None):
    self._hparams = HParams(hparams, self.default_hparams())
    src_hparams = self.hparams.source_dataset
    tgt_hparams = self.hparams.target_dataset

    # Create vocabularies
    self._src_bos_token = src_hparams["bos_token"]
    self._src_eos_token = src_hparams["eos_token"]
    self._src_transforms = src_hparams["other_transformations"]
    self._src_vocab = Vocab(src_hparams.vocab_file,
                            bos_token=src_hparams.bos_token,
                            eos_token=src_hparams.eos_token)

    if tgt_hparams["processing_share"]:
        self._tgt_bos_token = src_hparams["bos_token"]
        self._tgt_eos_token = src_hparams["eos_token"]
    else:
        self._tgt_bos_token = tgt_hparams["bos_token"]
        self._tgt_eos_token = tgt_hparams["eos_token"]
    tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                      SpecialTokens.BOS)
    tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                      SpecialTokens.EOS)
    if tgt_hparams["vocab_share"]:
        if (tgt_bos_token == self._src_vocab.bos_token and
                tgt_eos_token == self._src_vocab.eos_token):
            self._tgt_vocab = self._src_vocab
        else:
            # Shared vocab file, but different target special tokens.
            self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)
    else:
        self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                bos_token=tgt_bos_token,
                                eos_token=tgt_eos_token)

    # Create embeddings
    self._src_embedding = MonoTextData.make_embedding(
        src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)

    if self._hparams.target_dataset.embedding_init_share:
        self._tgt_embedding = self._src_embedding
    else:
        tgt_emb_file = tgt_hparams.embedding_init["file"]
        self._tgt_embedding = None
        if tgt_emb_file is not None and tgt_emb_file != "":
            self._tgt_embedding = MonoTextData.make_embedding(
                tgt_hparams.embedding_init,
                self._tgt_vocab.token_to_id_map_py)

    # Create data sources
    self._src_delimiter = src_hparams.delimiter
    self._src_max_seq_length = src_hparams.max_seq_length
    self._src_length_filter_mode = _LengthFilterMode(
        src_hparams.length_filter_mode)
    self._src_pad_length = self._src_max_seq_length
    if self._src_pad_length is not None:
        self._src_pad_length += sum(
            int(x is not None and x != '')
            for x in [src_hparams.bos_token, src_hparams.eos_token])

    src_data_source = TextLineDataSource(
        src_hparams.files,
        compression_type=src_hparams.compression_type)

    self._tgt_transforms = tgt_hparams["other_transformations"]
    self._tgt_delimiter = tgt_hparams.delimiter
    self._tgt_max_seq_length = tgt_hparams.max_seq_length
    self._tgt_length_filter_mode = _LengthFilterMode(
        tgt_hparams.length_filter_mode)
    self._tgt_pad_length = self._tgt_max_seq_length
    if self._tgt_pad_length is not None:
        self._tgt_pad_length += sum(
            int(x is not None and x != '')
            for x in [tgt_hparams.bos_token, tgt_hparams.eos_token])

    tgt_data_source = TextLineDataSource(
        tgt_hparams.files,
        compression_type=tgt_hparams.compression_type)

    data_source: DataSource[Tuple[List[str], List[str]]]
    data_source = ZipDataSource(  # type: ignore
        src_data_source, tgt_data_source)

    # Discard example pairs whose source or target side is too long.
    if ((self._src_length_filter_mode is _LengthFilterMode.DISCARD and
         self._src_max_seq_length is not None) or
            (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD and
             self._tgt_max_seq_length is not None)):
        max_source_length = self._src_max_seq_length or math.inf
        max_tgt_length = self._tgt_max_seq_length or math.inf

        def filter_fn(raw_example):
            return (len(raw_example[0]) <= max_source_length and
                    len(raw_example[1]) <= max_tgt_length)

        data_source = FilterDataSource(data_source, filter_fn)

    super().__init__(data_source, hparams, device=device)
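
# A hedged end-to-end sketch of the paired constructor above, assuming the
# class is texar.torch.data.PairedTextData; the file names are placeholders
# and the batch field names are assumed from the library's documented
# defaults for paired text data.
import texar.torch as tx

paired_hparams = {
    "source_dataset": {
        "files": "train.src",       # placeholder path
        "vocab_file": "vocab.txt",  # placeholder path
        "max_seq_length": 40,
        "length_filter_mode": "discard",
    },
    "target_dataset": {
        "files": "train.tgt",       # placeholder path
        "vocab_share": True,        # reuse the source Vocab, as handled above
        "max_seq_length": 40,
        "length_filter_mode": "discard",
    },
    "batch_size": 64,
}
paired_data = tx.data.PairedTextData(paired_hparams)
paired_iterator = tx.data.DataIterator(paired_data)
batch = next(iter(paired_iterator))
# Source and target sides are padded independently to their own lengths.
print(batch.source_text_ids.shape, batch.target_text_ids.shape)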