Example #1
def load_parallel_data(src_file, tgt_file, batch_size, sort_k_batches, dictionary, training=False):
    def preproc(s):
        s = s.replace('``', '"')
        s = s.replace('\'\'', '"')
        return s
    enc_dset = TextFile(files=[src_file], dictionary=dictionary,
            bos_token=None, eos_token=None, unk_token=CHAR_UNK_TOK, level='character', preprocess=preproc)
    dec_dset = TextFile(files=[tgt_file], dictionary=dictionary,
            bos_token=CHAR_SOS_TOK, eos_token=CHAR_EOS_TOK, unk_token=CHAR_UNK_TOK, level='character', preprocess=preproc)
    # NOTE merge the encoder (source) and decoder (target) streams together
    stream = Merge([enc_dset.get_example_stream(), dec_dset.get_example_stream()],
            ('source', 'target'))
    if training:
        # filter sequences that are too long
        stream = Filter(stream, predicate=TooLong(seq_len=CHAR_MAX_SEQ_LEN))
        # batch and read k batches ahead
        stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size*sort_k_batches))
        # sort all samples in read-ahead batch
        stream = Mapping(stream, SortMapping(lambda x: len(x[1])))
        # turn back into stream
        stream = Unpack(stream)
    # batch again
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    masked_stream = Padding(stream)
    return masked_stream
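These snippets build on the Fuel data-processing library but leave out their imports and project-level constants. A minimal sketch of what Example #1 assumes; the CHAR_* values and the TooLong predicate are project-specific assumptions rather than Fuel built-ins:

from fuel.datasets.text import TextFile
from fuel.schemes import ConstantScheme
from fuel.transformers import (Batch, Filter, Mapping, Merge, Padding,
                               SortMapping, Unpack)

# Assumed project-specific constants (not part of Fuel).
CHAR_UNK_TOK = '<unk>'
CHAR_SOS_TOK = '<s>'
CHAR_EOS_TOK = '</s>'
CHAR_MAX_SEQ_LEN = 300


class TooLong(object):
    """Hypothetical predicate: keep pairs whose sequences all fit in seq_len.

    Fuel's Filter keeps examples for which the predicate returns True.
    """

    def __init__(self, seq_len):
        self.seq_len = seq_len

    def __call__(self, pair):
        return all(len(seq) <= self.seq_len for seq in pair)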
Example #2
    def __init__(self, files, vocabulary_size=None, min_count=None,
                 load_dir=None, skip_window=3, num_skips=4):
        if load_dir is not None:
            self.dictionary = self.load_dictionary(load_dir)
        else:
            if vocabulary_size is not None:
                # reserve one vocabulary slot, presumably for the '<UNK>' token
                dictionary_vocab = vocabulary_size - 1
            else:
                dictionary_vocab = None

            self.dictionary = self.make_dictionary(files,
                                                   dictionary_vocab,
                                                   min_count)

        self.vocab_size = len(self.dictionary)

        text_data = TextFile(files,
                             self.dictionary,
                             unk_token='<UNK>',
                             bos_token=None,
                             eos_token=None,
                             preprocess=self._preprocess)
        stream = DataStream(text_data)
        self.data_stream = SkipGram(skip_window=skip_window,
                                    num_skips=num_skips,
                                    data_stream=stream)
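A hedged usage sketch for Example #2. The enclosing class name used here (Word2VecCorpus) is hypothetical, and the SkipGram transformer is specific to this project, so the exact shape of what it emits is an assumption:

# Hypothetical class name; only its __init__ is shown above.
corpus = Word2VecCorpus(files=['corpus.txt'],
                        vocabulary_size=50000,
                        skip_window=3,
                        num_skips=4)

# Fuel streams expose epoch iterators; each item is whatever the SkipGram
# transformer emits, presumably (center_word, context_word) index pairs.
for example in corpus.data_stream.get_epoch_iterator():
    pass  # feed into an embedding model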
Example #3
def load_parallel_data(src_file,
                       tgt_file,
                       batch_size,
                       sort_k_batches,
                       dictionary,
                       training=False):
    def preproc(s):
        s = s.replace('``', '"')
        s = s.replace('\'\'', '"')
        return s

    enc_dset = TextFile(files=[src_file],
                        dictionary=dictionary,
                        bos_token=None,
                        eos_token=None,
                        unk_token=CHAR_UNK_TOK,
                        level='character',
                        preprocess=preproc)
    dec_dset = TextFile(files=[tgt_file],
                        dictionary=dictionary,
                        bos_token=CHAR_SOS_TOK,
                        eos_token=CHAR_EOS_TOK,
                        unk_token=CHAR_UNK_TOK,
                        level='character',
                        preprocess=preproc)
    # NOTE merge the encoder (source) and decoder (target) streams together
    stream = Merge(
        [enc_dset.get_example_stream(),
         dec_dset.get_example_stream()], ('source', 'target'))
    if training:
        # filter sequences that are too long
        stream = Filter(stream, predicate=TooLong(seq_len=CHAR_MAX_SEQ_LEN))
        # batch and read k batches ahead
        stream = Batch(stream,
                       iteration_scheme=ConstantScheme(batch_size *
                                                       sort_k_batches))
        # sort all samples in read-ahead batch
        stream = Mapping(stream, SortMapping(lambda x: len(x[1])))
        # turn back into stream
        stream = Unpack(stream)
    # batch again
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    masked_stream = Padding(stream)
    return masked_stream
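Example #3 is the same loader as Example #1, only reformatted. A hedged sketch of consuming the padded stream it returns; the file paths and dictionary are placeholders, and the ordering of sources follows Fuel's Padding convention of appending a '_mask' entry after each padded source:

stream = load_parallel_data('train.src', 'train.tgt',
                            batch_size=32, sort_k_batches=12,
                            dictionary=char_dict, training=True)

# Padding turns ('source', 'target') into
# ('source', 'source_mask', 'target', 'target_mask').
for source, source_mask, target, target_mask in stream.get_epoch_iterator():
    # source: (batch, max_src_len) array of character ids, zero-padded
    # source_mask: same shape, 1.0 where a real character is present
    break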
Example #4
def load_data(src_file, tgt_file, batch_size, sort_k_batches, training=False):
    src_dict, tgt_dict = load_dictionaries()

    src_dset = TextFile(files=[src_file], dictionary=src_dict,
            bos_token=None, eos_token=None, unk_token=WORD_UNK_TOK)
    tgt_dset = TextFile(files=[tgt_file], dictionary=tgt_dict,
            bos_token=WORD_EOS_TOK, eos_token=WORD_EOS_TOK, unk_token=WORD_UNK_TOK)

    stream = Merge([src_dset.get_example_stream(), tgt_dset.get_example_stream()],
            ('source', 'target'))
    if training:
        # filter sequences that are too long
        stream = Filter(stream, predicate=TooLong(seq_len=WORD_MAX_SEQ_LEN))
        # batch and read k batches ahead
        stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size*sort_k_batches))
        # sort all samples in read-ahead batch
        stream = Mapping(stream, SortMapping(lambda x: len(x[1])))
        # turn back into stream
        stream = Unpack(stream)
    # batch again
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    # NOTE Padding fills with zeros, so the EOS index should be 0
    masked_stream = Padding(stream)
    return masked_stream, src_dict, tgt_dict
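Example #4 depends on a load_dictionaries helper and WORD_* constants that are not shown. A minimal sketch of what they might look like, keeping the EOS index at 0 so that Padding's zero-fill coincides with EOS, as the NOTE above requires; the vocabulary file names are assumptions:

WORD_UNK_TOK = '<unk>'
WORD_EOS_TOK = '</s>'
WORD_MAX_SEQ_LEN = 50


def load_dictionaries():
    """Hypothetical helper: build word-to-index dicts with EOS mapped to 0."""
    def build(vocab_file):
        vocab = {WORD_EOS_TOK: 0, WORD_UNK_TOK: 1}
        with open(vocab_file, encoding='utf-8') as f:
            for line in f:
                word = line.strip()
                if word and word not in vocab:
                    vocab[word] = len(vocab)
        return vocab

    return build('vocab.src.txt'), build('vocab.tgt.txt')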
Example #5
def get_stream(source,
               target,
               source_input_dict,
               target_label_dict,
               batch_size,
               buffer_multiplier=100,
               input_token_level='word',
               n_input_tokens=0,
               n_labels=0,
               reverse_labels=False,
               max_input_length=None,
               max_label_length=None,
               pad_labels=True,
               is_sort=True):
    """Returns a stream over sentence pairs.

    Parameters
    ----------
    source : list
        A list of files to read source languages from.
    target : list
        A list of corresponding files in the target language.
    source_input_dict : str
        Path to a tab-delimited text file whose last column contains the
        vocabulary.
    target_label_dict : str
        See `source_input_dict`.
    batch_size : int
        The minibatch size.
    buffer_multiplier : int
        The number of batches to load, concatenate, sort by length of
        source sentence, and split again; this makes batches more uniform
        in their sentence length and hence more computationally efficient.
    n_input_tokens : int
        The number of tokens in the source vocabulary. Pass 0 (default) to
        use the entire vocabulary.
    n_labels : int
        See `n_input_tokens`.

    """
    if len(source) != len(target):
        raise ValueError("number of source and target files don't match")

    # Read the dictionaries
    dicts = [
        load_dict(source_input_dict, dict_size=n_input_tokens),
        load_dict(target_label_dict,
                  dict_size=n_labels,
                  reverse=reverse_labels,
                  include_unk=False)
    ]

    # Open the two sets of files and merge them
    streams = [
        TextFile(source,
                 dicts[0],
                 level=input_token_level,
                 bos_token=None,
                 eos_token=EOS_TOKEN,
                 encoding='utf-8').get_example_stream(),
        TextFile(target,
                 dicts[1],
                 level='word',
                 bos_token=None,
                 unk_token=None,
                 eos_token=EOS_TOKEN,
                 encoding='utf-8').get_example_stream()
    ]
    merged = Merge(streams, ('source_input_tokens', 'target_labels'))
    if reverse_labels:
        merged = SortLabels(merged)

    # Filter sentence lengths
    if max_input_length or max_label_length:

        def filter_pair(pair):
            src_input_tokens, trg_labels = pair
            src_input_ok = (not max_input_length) or \
                len(src_input_tokens) <= (max_input_length + 1)
            trg_label_ok = (not max_label_length) or \
                len(trg_labels) <= (max_label_length + 1)

            return src_input_ok and trg_label_ok

        merged = Filter(merged, filter_pair)

    # Batches of approximately uniform size
    large_batches = Batch(merged,
                          iteration_scheme=ConstantScheme(batch_size *
                                                          buffer_multiplier))
    # sorted_batches = Mapping(large_batches, SortMapping(_source_length))
    # batches = Cache(sorted_batches, ConstantScheme(batch_size))
    # shuffled_batches = Shuffle(batches, buffer_multiplier)
    # masked_batches = Padding(shuffled_batches,
    #                          mask_sources=('source_chars', 'target_labels'))
    if is_sort:
        sorted_batches = Mapping(large_batches, SortMapping(_source_length))
    else:
        sorted_batches = large_batches
    batches = Cache(sorted_batches, ConstantScheme(batch_size))
    mask_sources = ('source_input_tokens', 'target_labels')
    masked_batches = Padding(batches, mask_sources=mask_sources)

    return masked_batches
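Both get_stream variants in this example (the one above and the one that follows) sort the buffered batches with SortMapping(_source_length), but _source_length itself is not shown. A minimal sketch, assuming each example is a (source, target) pair of token-id lists:

def _source_length(sentence_pair):
    """Sort key for SortMapping: the length of the source sequence."""
    return len(sentence_pair[0])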
def get_stream(source,
               target,
               source_dict,
               target_dict,
               batch_size,
               buffer_multiplier=100,
               n_words_source=0,
               n_words_target=0,
               max_src_length=None,
               max_trg_length=None):
    """Returns a stream over sentence pairs.

    Parameters
    ----------
    source : list
        A list of files to read source languages from.
    target : list
        A list of corresponding files in the target language.
    source_dict : str
        Path to a tab-delimited text file whose last column contains the
        vocabulary.
    target_dict : str
        See `source_dict`.
    batch_size : int
        The minibatch size.
    buffer_multiplier : int
        The number of batches to load, concatenate, sort by length of
        source sentence, and split again; this makes batches more uniform
        in their sentence length and hence more computationally efficient.
    n_words_source : int
        The number of words in the source vocabulary. Pass 0 (default) to
        use the entire vocabulary.
    n_words_target : int
        See `n_words_source`.

    """
    if len(source) != len(target):
        raise ValueError("number of source and target files don't match")

    # Read the dictionaries
    dicts = [
        load_dict(source_dict, n_words=n_words_source),
        load_dict(target_dict, n_words=n_words_target)
    ]

    # Open the two sets of files and merge them
    streams = [
        TextFile(source, dicts[0], bos_token=None,
                 eos_token=EOS_TOKEN).get_example_stream(),
        TextFile(target, dicts[1], bos_token=None,
                 eos_token=EOS_TOKEN).get_example_stream()
    ]
    merged = Merge(streams, ('source', 'target'))

    # Filter sentence lengths
    if max_src_length or max_trg_length:

        def filter_pair(pair):
            src, trg = pair
            src_ok = (not max_src_length) or len(src) < max_src_length
            trg_ok = (not max_trg_length) or len(trg) < max_trg_length
            return src_ok and trg_ok

        merged = Filter(merged, filter_pair)

    # Batches of approximately uniform size
    large_batches = Batch(merged,
                          iteration_scheme=ConstantScheme(batch_size *
                                                          buffer_multiplier))
    sorted_batches = Mapping(large_batches, SortMapping(_source_length))
    batches = Cache(sorted_batches, ConstantScheme(batch_size))
    shuffled_batches = Shuffle(batches, buffer_multiplier)
    masked_batches = Padding(shuffled_batches)

    return masked_batches
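Finally, a hedged usage sketch for the second get_stream variant; load_dict, EOS_TOKEN, and the file names here are placeholders or project-level helpers assumed by these snippets:

train_stream = get_stream(source=['train.de'],
                          target=['train.en'],
                          source_dict='vocab.de.txt',
                          target_dict='vocab.en.txt',
                          batch_size=80,
                          max_src_length=50,
                          max_trg_length=50)

# With the default mask_sources, Padding yields
# ('source', 'source_mask', 'target', 'target_mask') per batch.
for source, source_mask, target, target_mask in train_stream.get_epoch_iterator():
    # each array is (batch, max_len); the masks mark real tokens vs. padding
    break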