def prepare_training_data(src_files, src_files_encoding, trg_files,
                          trg_files_encoding, src_output_file,
                          trg_output_file):
    '''
    For each pair of source/target files, check that both sides contain the
    same number of sentences, then shuffle the parallel corpus and save both
    sides with UTF-8 encoding.
    '''
    src = chain(
        *[iter_(open_(f, encoding=src_files_encoding)) for f in src_files])
    trg = chain(
        *[iter_(open_(f, encoding=trg_files_encoding)) for f in trg_files])

    # TODO: find a way not to load all sentences into memory
    logger.info("reading sentences from source files...")
    src_sentences = [sent for sent in src]
    logger.info("reading sentences from target files...")
    trg_sentences = [sent for sent in trg]

    assert len(src_sentences) == len(trg_sentences)
    logger.info("number of sentences:%d" % len(src_sentences))

    # make sure the trailing '\n' has not been stripped, so each sentence
    # is written back on its own line
    assert src_sentences[0].endswith('\n')
    # do the shuffle
    ids = list(range(len(src_sentences)))
    random.shuffle(ids)

    with codecs.open(src_output_file, 'w', 'UTF-8') as f_src:
        with codecs.open(trg_output_file, 'w', 'UTF-8') as f_trg:
            for i in ids:
                f_src.write(src_sentences[i])
                f_trg.write(trg_sentences[i])
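A minimal usage sketch for prepare_training_data (the module name data_utils and the corpus file names are illustrative assumptions; the snippet also relies on the project-level open_/iter_ helpers, the module logger, and the chain, random and codecs imports):

# hypothetical usage; module and file names are illustrative
from data_utils import prepare_training_data

prepare_training_data(
    src_files=['train.de'],
    src_files_encoding='UTF-8',
    trg_files=['train.en'],
    trg_files_encoding='UTF-8',
    src_output_file='train.shuffled.de',
    trg_output_file='train.shuffled.en')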
Example #2
def _build_vocabulary(files,
                      encoding,
                      eos='</S>',
                      eos_id=0,
                      unk='<UNK>',
                      unk_id=1,
                      max_nb_of_vocabulary=None,
                      preprocess=to_lower_case):
    stat = {}
    for filepath in files:
        with open_(filepath, encoding=encoding) as f:
            for line in f:
                if preprocess is not None:
                    line = preprocess(line)
                words = line.split()
                # count word frequencies (numbers are not replaced with a NUM token here)
                for word in words:
                    if word in stat:
                        stat[word] += 1
                    else:
                        stat[word] = 1
    sorted_items = sorted(stat.items(), key=lambda d: d[1], reverse=True)
    if max_nb_of_vocabulary is not None:
        sorted_items = sorted_items[:max_nb_of_vocabulary]
    vocab = {}
    vocab[eos] = eos_id
    vocab[unk] = unk_id
    special_token_idxs = set([eos_id, unk_id])
    token_id = 0
    for token, _ in sorted_items:
        while token_id in special_token_idxs:
            token_id += 1
        vocab[token] = token_id
        token_id += 1
    return vocab
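A hedged usage sketch for _build_vocabulary (the corpus file name is illustrative; to_lower_case is the project helper already referenced by the default preprocess argument):

# hypothetical call; 'train.en' stands in for a real corpus file
vocab = _build_vocabulary(files=['train.en'],
                          encoding='UTF-8',
                          max_nb_of_vocabulary=30000)
# '</S>' and '<UNK>' always receive the reserved ids 0 and 1
assert vocab['</S>'] == 0 and vocab['<UNK>'] == 1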
Example #3
def get_line_count(files, encoding='UTF-8'):
    count = 0
    for file_path in files:
        with open_(file_path, encoding=encoding) as f:
            for line in f:
                # count only lines that are not empty
                if line.rstrip('\n'):
                    count += 1
    return count
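A tiny usage sketch for get_line_count (file names are illustrative):

# counts non-blank lines across several corpus files
n_lines = get_line_count(['train.de', 'train.en'])
print('non-blank lines:', n_lines)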
Example #4
File: text.py Project: Afrik/fuel
    def open(self):
        return chain(*[iter_(open_(f, encoding=self.encoding))
                       for f in self.files])
Example #5
    def open(self):
        return chain(
            *[iter_(open_(f, encoding=self.encoding)) for f in self.files])
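Both open methods lazily chain one line iterator per file into a single stream, so all files can be read in sequence without loading them into memory. A minimal self-contained sketch of the same pattern (the MultiFileLines class is illustrative, and io.open plus the builtin iter stand in for the project's open_/iter_ helpers):

import io
from itertools import chain


class MultiFileLines(object):
    def __init__(self, files, encoding='UTF-8'):
        self.files = files
        self.encoding = encoding

    def open(self):
        # one line iterator per file, chained into a single lazy stream
        return chain(
            *[iter(io.open(f, encoding=self.encoding)) for f in self.files])

Iterating over the iterator returned by open() yields every line of every file in order.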