Example #1
    def __init__(self,
                 path,
                 src_max_len,
                 tgt_max_len,
                 src_vocab=None,
                 tgt_vocab=None,
                 src_max_vocab=50000,
                 tgt_max_vocab=50000):
        # Prepare data
        self.src_max_len = src_max_len
        self.tgt_max_len = tgt_max_len
        # debug_here()
        pairs = utils.prepare_data(path, src_max_len, tgt_max_len)

        # pairs[0]: sample 1, pairs[0][0]: src, pairs[0][1]: tgt
        # Read in vocabularies
        # A zip object is not subscriptable in Python 3, so materialize it as a list:
        # all the srcs:             pairs_zipped[0][0], pairs_zipped[0][1], pairs_zipped[0][2] ...
        # corresp. to all the tgts: pairs_zipped[1][0], pairs_zipped[1][1], pairs_zipped[1][2] ...
        pairs_zipped = list(zip(*pairs))

        self.input_vocab = self._init_vocab(pairs_zipped[0], src_max_vocab,
                                            src_vocab)
        self.output_vocab = self._init_vocab(pairs_zipped[1], tgt_max_vocab,
                                             tgt_vocab)

        # Translate input sequences to token ids
        self.data = []
        for pair in pairs:
            src = self.input_vocab.indices_from_sequence(pair[0])
            dst = self.output_vocab.indices_from_sequence(pair[1])
            self.data.append((src, dst))
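
The constructor above assumes `utils.prepare_data` returns tokenized (src, tgt) pairs read from a TAB-separated file. A minimal usage sketch; the class name `Dataset` and the file path are assumptions, not taken from the snippet:

# Hypothetical usage; "Dataset" and the path are placeholders.
dataset = Dataset("data/train.txt",
                  src_max_len=50,
                  tgt_max_len=50,
                  src_max_vocab=30000,
                  tgt_max_vocab=30000)
print(len(dataset.data))   # number of (src_ids, tgt_ids) pairs
print(dataset.data[0])     # first pair of token-id sequences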
Example #2
    @classmethod
    def from_file(cls,
                  path,
                  src_max_len,
                  tgt_max_len,
                  src_vocab=None,
                  tgt_vocab=None,
                  src_max_vocab=50000,
                  tgt_max_vocab=50000):
        """
        Initialize a dataset from the file at given path. The file
        must contains a list of TAB-separated pairs of sequences.

        Note:
            Source or target sequences that are longer than the respective
            max length will be filtered.
            As specified by maximum vocabulary size, source and target
            vocabularies will be sorted in descending token frequency and cutoff.
            Tokens that are in the dataset but not retained in the vocabulary
            will be dropped in the sequences.

        Args:
            path (str): path to the dataset file
            src_max_len (int): maximum source sequence length
            tgt_max_len (int): maximum target sequence length
            src_vocab (Vocabulary): pre-populated Vocabulary object or a path of a file containing words for the source language,
            default `None`. If a pre-populated Vocabulary object, `src_max_vocab` wouldn't be used.
            tgt_vocab (Vocabulary): pre-populated Vocabulary object or a path of a file containing words for the target language,
            default `None`. If a pre-populated Vocabulary object, `tgt_max_vocab` wouldn't be used.
            src_max_vocab (int): maximum source vocabulary size
            tgt_max_vocab (int): maximum target vocabulary size
        """
        obj = cls(src_max_len, tgt_max_len)
        pairs = utils.prepare_data(path, src_max_len, tgt_max_len)
        return cls._encode(obj, pairs, src_vocab, tgt_vocab, src_max_vocab,
                           tgt_max_vocab)
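
The docstring describes the vocabulary cutoff: tokens are ranked by frequency and only the most frequent ones are kept. A minimal sketch of that idea using `collections.Counter`; the helper name and its return type are assumptions, not the project's actual `Vocabulary` API:

from collections import Counter

def build_vocab(sequences, max_vocab):
    # Count token frequencies over all tokenized sequences.
    counts = Counter(token for seq in sequences for token in seq)
    # Keep only the max_vocab most frequent tokens, highest frequency first.
    kept = [token for token, _ in counts.most_common(max_vocab)]
    # Map each retained token to an integer id; anything else is dropped later.
    return {token: idx for idx, token in enumerate(kept)}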
Example #3
    def __init__(self,
                 path,
                 src_max_len,
                 tgt_max_len,
                 src_vocab=None,
                 tgt_vocab=None,
                 src_max_vocab=50000,
                 tgt_max_vocab=50000):
        # Prepare data
        self.src_max_len = src_max_len
        self.tgt_max_len = tgt_max_len
        pairs = utils.prepare_data(path, src_max_len, tgt_max_len)

        # Read in vocabularies. A zip object is not subscriptable in Python 3,
        # so materialize the transposed pairs as a list before indexing.
        pairs_zipped = list(zip(*pairs))
        self.input_vocab = self._init_vocab(
            pairs_zipped[0], src_max_vocab, src_vocab)
        self.output_vocab = self._init_vocab(
            pairs_zipped[1], tgt_max_vocab, tgt_vocab)

        # Translate input sequences to token ids
        self.data = []
        for pair in pairs:
            src = self.input_vocab.indices_from_sequence(pair[0])
            dst = self.output_vocab.indices_from_sequence(pair[1])
            self.data.append((src, dst))
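
As noted in Example #1, `zip(...)` returns a one-shot iterator in Python 3, so indexing it directly fails; the version above materializes it with `list()` first. A small standalone illustration with made-up pairs:

# Made-up pairs, for illustration only.
pairs = [(["hello", "world"], ["bonjour", "monde"]),
         (["good", "night"], ["bonne", "nuit"])]

srcs, tgts = list(zip(*pairs))   # transpose the list of (src, tgt) pairs
print(srcs)                      # (['hello', 'world'], ['good', 'night'])
print(tgts)                      # (['bonjour', 'monde'], ['bonne', 'nuit'])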
    def test_prepare_data_WITH_VALID_PATH(self):
        pairs = utils.prepare_data("./tests/data/eng-fra.txt", 20, 20)
        self.assertEqual(100, len(pairs))
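
The test expects `utils.prepare_data` to read a TAB-separated file and keep only pairs within the length limits. A minimal sketch of such a helper, assuming whitespace tokenization; the project's real implementation may differ:

def prepare_data(path, src_max_len, tgt_max_len):
    # Each line holds "source<TAB>target"; tokenize both sides by whitespace.
    pairs = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            src, tgt = line.rstrip("\n").split("\t")
            src_tokens, tgt_tokens = src.split(), tgt.split()
            # Filter out pairs whose source or target exceeds its maximum length.
            if len(src_tokens) <= src_max_len and len(tgt_tokens) <= tgt_max_len:
                pairs.append((src_tokens, tgt_tokens))
    return pairs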