Example #1
    @staticmethod
    def make_embedding(src_emb_hparams, src_token_to_id_map,
                       tgt_emb_hparams=None, tgt_token_to_id_map=None,
                       emb_init_share=False):
        r"""Optionally loads source and target embeddings from files
        (if provided), and returns respective :class:`texar.data.Embedding`
        instances.
        """
        src_embedding = MonoTextData.make_embedding(src_emb_hparams,
                                                    src_token_to_id_map)

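        # Reuse the source embedding object for the target side when
        # initialization is shared.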
        if emb_init_share:
            tgt_embedding = src_embedding
        else:
            # Guard against the default tgt_emb_hparams=None, which would
            # otherwise raise a TypeError on the subscript below.
            tgt_embedding = None
            if tgt_emb_hparams is not None:
                tgt_emb_file = tgt_emb_hparams["file"]
                if tgt_emb_file is not None and tgt_emb_file != "":
                    tgt_embedding = Embedding(tgt_token_to_id_map,
                                              tgt_emb_hparams)

        return src_embedding, tgt_embedding
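A minimal usage sketch of the helper above, assuming it is the static method of texar's `PairedTextData` (TensorFlow version) and that `Vocab` exposes `token_to_id_map_py`; the file paths and hparams values are placeholders, not shipped defaults.

    # Hedged sketch, not a verbatim texar recipe: the paths below are
    # hypothetical, and the hparams keys follow texar.data.Embedding's
    # documented "file"/"dim" fields.
    from texar.data import Vocab, PairedTextData

    vocab = Vocab("data/vocab.txt")  # placeholder vocabulary file
    emb_hparams = {"file": "data/glove.6B.100d.txt", "dim": 100}

    # With emb_init_share=True the target side reuses the source
    # Embedding instance, so the target hparams/vocab may stay None.
    src_emb, tgt_emb = PairedTextData.make_embedding(
        emb_hparams, vocab.token_to_id_map_py, emb_init_share=True)
    assert tgt_emb is src_emb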
Example #2
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())

        src_hparams = self.hparams.source_dataset
        tgt_hparams = self.hparams.target_dataset

        # create vocabulary
        self._src_bos_token = src_hparams["bos_token"]
        self._src_eos_token = src_hparams["eos_token"]
        self._src_transforms = src_hparams["other_transformations"]
        self._src_vocab = Vocab(src_hparams.vocab_file,
                                bos_token=src_hparams.bos_token,
                                eos_token=src_hparams.eos_token)

        if tgt_hparams["processing_share"]:
            self._tgt_bos_token = src_hparams["bos_token"]
            self._tgt_eos_token = src_hparams["eos_token"]
        else:
            self._tgt_bos_token = tgt_hparams["bos_token"]
            self._tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                          SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                          SpecialTokens.EOS)
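        # Share the source vocabulary only if the target side asks for the
        # same BOS/EOS tokens; otherwise rebuild from the same vocab file so
        # the special tokens are remapped.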
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == self._src_vocab.bos_token and \
                    tgt_eos_token == self._src_vocab.eos_token:
                self._tgt_vocab = self._src_vocab
            else:
                self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                        bos_token=tgt_bos_token,
                                        eos_token=tgt_eos_token)
        else:
            self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)

        # create embeddings
        self._src_embedding = MonoTextData.make_embedding(
            src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)

        if tgt_hparams.embedding_init_share:
            self._tgt_embedding = self._src_embedding
        else:
            tgt_emb_file = tgt_hparams.embedding_init["file"]
            self._tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                self._tgt_embedding = MonoTextData.make_embedding(
                    tgt_hparams.embedding_init,
                    self._tgt_vocab.token_to_id_map_py)

        # create data source
        self._src_delimiter = src_hparams.delimiter
        self._src_max_seq_length = src_hparams.max_seq_length
        self._src_length_filter_mode = _LengthFilterMode(
            src_hparams.length_filter_mode)
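        # Widen the pad length by one slot for each special token (BOS/EOS)
        # that will be prepended/appended to every source sequence.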
        self._src_pad_length = self._src_max_seq_length
        if self._src_pad_length is not None:
            self._src_pad_length += sum(int(x is not None and x != '')
                                        for x in [src_hparams.bos_token,
                                                  src_hparams.eos_token])

        src_data_source = TextLineDataSource(
            src_hparams.files,
            compression_type=src_hparams.compression_type)

        self._tgt_transforms = tgt_hparams["other_transformations"]
        self._tgt_delimiter = tgt_hparams.delimiter
        self._tgt_max_seq_length = tgt_hparams.max_seq_length
        self._tgt_length_filter_mode = _LengthFilterMode(
            tgt_hparams.length_filter_mode)
        self._tgt_pad_length = self._tgt_max_seq_length
        if self._tgt_pad_length is not None:
            self._tgt_pad_length += sum(int(x is not None and x != '')
                                        for x in [tgt_hparams.bos_token,
                                                  tgt_hparams.eos_token])

        tgt_data_source = TextLineDataSource(
            tgt_hparams.files,
            compression_type=tgt_hparams.compression_type)

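        # Zip the two line sources so every raw example is a (src, tgt) pair.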
        data_source: DataSource[Tuple[str, str]]
        data_source = ZipDataSource(  # type: ignore
            src_data_source, tgt_data_source)
        if (self._src_length_filter_mode is _LengthFilterMode.DISCARD and
            self._src_max_seq_length is not None) or \
                (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD and
                 self._tgt_max_seq_length is not None):
            max_source_length = self._src_max_seq_length if \
                self._src_max_seq_length is not None else np.inf
            max_tgt_length = self._tgt_max_seq_length if \
                self._tgt_max_seq_length is not None else np.inf

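            # Keep only pairs in which both sides respect their length caps;
            # np.inf disables the cap for an unconstrained side.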
            def filter_fn(raw_example):
                return len(raw_example[0].split(self._src_delimiter)) \
                       <= max_source_length and \
                       len(raw_example[1].split(self._tgt_delimiter)) \
                       <= max_tgt_length

            data_source = FilterDataSource(data_source, filter_fn)

        super().__init__(data_source, hparams, device=device)
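For orientation, a sketch of how this constructor is typically driven, assuming the texar-pytorch `PairedTextData` class and `DataIterator`; the corpus and vocabulary paths are placeholders, and the batch field names follow texar-pytorch's `source_`/`target_` naming.

    # Hedged sketch: paths are hypothetical; the hparams keys mirror the
    # source_dataset/target_dataset layout read by __init__ above.
    import torch
    from texar.torch.data import DataIterator, PairedTextData

    hparams = {
        "source_dataset": {"files": "train.src", "vocab_file": "vocab.txt"},
        "target_dataset": {"files": "train.tgt",
                           "vocab_share": True,        # reuse source vocab
                           "processing_share": True},  # reuse BOS/EOS config
        "batch_size": 32,
    }
    data = PairedTextData(hparams, device=torch.device("cpu"))
    for batch in DataIterator(data):
        print(batch.source_text_ids.shape, batch.target_text_ids.shape)
        break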