示例#1
0
def count_vocab(data_file, max_vcb_size, max_seq_len=50, char=False):
    """Build a size-limited vocabulary from a one-sentence-per-line corpus.

    Each line of ``data_file`` is decoded as UTF-8 and stripped. When ``char``
    is True it is tokenized with ``zh_to_chars``; otherwise it is split on
    whitespace. Lines with more than ``max_seq_len`` tokens are skipped
    entirely. Every remaining token is added to a fresh ``Vocab``, which is
    then pruned with ``Vocab.keep_vocab_size(max_vcb_size)``. The kept/total
    token ratio is logged through ``wlog``.

    Args:
        data_file: path of the corpus file.
        max_vcb_size: size limit forwarded to ``Vocab.keep_vocab_size``.
        max_seq_len: lines with more tokens than this are not counted.
        char: if True, tokenize with ``zh_to_chars`` instead of ``split()``.

    Returns:
        The pruned vocabulary. Its ``idx2key`` values and ``key2idx`` keys
        have been passed through ``str``.

    Raises:
        AssertionError: if ``data_file`` is empty or does not exist.
    """
    assert data_file and os.path.exists(
        data_file), 'need file to extract vocabulary ...'

    vocab = Vocab()
    with io.open(data_file, encoding='utf-8') as f:
        # Iterate the file lazily instead of readlines() so a large corpus is
        # not loaded into memory all at once.
        for sent in f:
            sent = sent.strip()
            if char is True: words = zh_to_chars(sent)
            else: words = sent.split()
            if len(words) > max_seq_len: continue
            for word in words:
                vocab.add(word)

    words_cnt = sum(vocab.freq.values())
    new_vocab, new_words_cnt = vocab.keep_vocab_size(max_vcb_size)
    # An empty corpus (or one where every line exceeds max_seq_len) leaves
    # words_cnt == 0. 100.0 forces float division under Python 2 as well.
    coverage = 100.0 * new_words_cnt / words_cnt if words_cnt else 0.0
    wlog('|Final vocabulary| / |Original vocabulary| = {} / {} = {:4.2f}%'.
         format(new_words_cnt, words_cnt, coverage))

    new_vocab.idx2key = {k: str(v) for k, v in new_vocab.idx2key.items()}
    new_vocab.key2idx = {str(k): v for k, v in new_vocab.key2idx.items()}

    return new_vocab
示例#2
0
def wrap_tst_data(src_data, src_vocab, char=False):
    """Read a monolingual test file and map every line to vocabulary indices.

    Each line produces exactly one output entry, so the outputs stay aligned
    with the line numbers of ``src_data``. A blank line is passed to
    ``keys2idx`` as an empty token list.

    Args:
        src_data: path of the UTF-8 source file, one sentence per line.
        src_vocab: vocabulary object; its ``keys2idx(words, UNK_WORD)`` is
            called once per line.
        char: if True, re-tokenize each stripped line with ``zh_to_chars``
            before splitting on whitespace.

    Returns:
        ``(srcs, slens)``. ``srcs`` holds the result of
        ``src_vocab.keys2idx`` for each line, and ``slens`` holds the token
        count of each line.
    """
    srcs, slens = [], []
    idx = 0

    # ``with`` closes the file even if tokenization or the vocabulary lookup
    # raises part-way through.
    with io.open(src_data, mode='r', encoding='utf-8') as srcF:
        for idx, src_sent in enumerate(srcF, start=1):
            src_sent = src_sent.strip()
            if char is True: src_sent = ' '.join(zh_to_chars(src_sent))
            src_words = src_sent.split()
            srcs.append(src_vocab.keys2idx(src_words, UNK_WORD))
            slens.append(len(src_words))

    wlog('\nFinish to read monolingual test dataset {}, count {}'.format(src_data, idx))
    return srcs, slens
示例#3
0
def wrap_data(data_dir, file_prefix, src_suffix, trg_prefix, src_vocab, trg_vocab,
              shuffle=True, sort_data=True, max_seq_len=50, char=False):
    """Load a parallel corpus with one or more target references as index lists.

    The source file is ``<data_dir>/<file_prefix>.<src_suffix>``. Every file
    in ``data_dir`` whose name starts with ``<file_prefix>.<trg_prefix>`` is
    opened as a target reference, in sorted filename order. All files are
    decoded as UTF-8 and read line by line in lockstep. For each line:

    * if the source or any reference is blank, the pair is counted as
      ``ignore`` and skipped;
    * if the source or any reference has more than ``max_seq_len`` tokens,
      the pair is counted as ``longer`` and skipped;
    * otherwise the source is mapped with
      ``src_vocab.keys2idx(words, UNK_WORD)``, and every reference with
      ``trg_vocab.keys2idx(words, UNK_WORD, bos_word=BOS_WORD,
      eos_word=EOS_WORD)``.

    Args:
        data_dir: directory holding the corpus files.
        file_prefix: filename prefix shared by the source and reference files.
        src_suffix: extension of the source file.
        trg_prefix: start of the extension shared by all reference files.
        src_vocab: source vocabulary, used via ``keys2idx``.
        trg_vocab: target vocabulary, used via ``keys2idx``.
        shuffle: if True, permute the kept pairs with ``tc.randperm``.
        sort_data: if True, order pairs by ascending source length. When
            ``wargs.sort_k_batches == 0`` the whole set is sorted; otherwise
            each chunk of ``wargs.batch_size * wargs.sort_k_batches`` pairs
            is sorted on its own.
        max_seq_len: maximum token count allowed for the source and for each
            reference.
        char: if True, re-tokenize the source side with ``zh_to_chars``.

    Returns:
        ``(final_srcs, final_trgs)``. ``final_srcs`` holds one source index
        list per pair. ``final_trgs`` holds, per pair, a list with one index
        list per reference.
    """
    src_path = os.path.join(data_dir, '{}.{}'.format(file_prefix, src_suffix))
    trg_name_prefix = '{}.{}'.format(file_prefix, trg_prefix)

    # First pass only counts source lines, to pace the progress output below.
    with io.open(src_path, mode='r', encoding='utf-8') as f:
        num = sum(1 for _ in f)
    # max(1, ...) keeps the moduli non-zero when the source file is empty.
    point_every = max(1, int(math.ceil(num / 100.0)))
    number_every = max(1, int(math.ceil(num / 10.0)))

    srcF = io.open(src_path, mode='r', encoding='utf-8')
    trgFs = []  # maybe have multi-references for valid, we open them together
    try:
        # os.listdir order is arbitrary; sort so the reference order is reproducible.
        for fname in sorted(os.listdir(data_dir)):
            if fname.startswith(trg_name_prefix):
                trg_path = os.path.join(data_dir, fname)
                wlog('\t{}'.format(trg_path))
                # Decode targets as UTF-8, exactly like the source side.
                trgFs.append(io.open(trg_path, mode='r', encoding='utf-8'))
        wlog('NOTE: Target side has {} references.'.format(len(trgFs)))

        idx, ignore, longer = 0, 0, 0
        srcs, trgs, slens = [], [], []
        while True:
            src_line = srcF.readline()
            trg_lines = [trgF.readline() for trgF in trgFs]
            # readline() returns '' only at end of file; a blank line is '\n'.
            # Test the raw lines here: testing after strip() would stop reading
            # at the first line that is blank on every side.
            if src_line == '' and all(line == '' for line in trg_lines):
                wlog('\nFinish to read bilingual corpus.')
                break

            src_sent = src_line.strip()
            if char is True: src_sent = ' '.join(zh_to_chars(src_sent))
            trg_refs = [line.strip() for line in trg_lines]

            if (idx + 1) % point_every == 0: wlog('.', False)
            if (idx + 1) % number_every == 0: wlog('{}'.format(idx + 1), False)
            idx += 1

            # A blank or missing line on any side cannot form a training pair.
            if src_sent == '' or any(trg_ref == '' for trg_ref in trg_refs):
                wlog('Ignore abnormal blank sentence in line number {}'.format(idx))
                ignore += 1
                continue

            src_words = src_sent.split()
            src_len = len(src_words)
            trg_refs_words = [trg_ref.split() for trg_ref in trg_refs]
            if src_len > max_seq_len or any(len(tws) > max_seq_len for tws in trg_refs_words):
                longer += 1
                continue

            srcs.append(src_vocab.keys2idx(src_words, UNK_WORD))
            trgs.append([trg_vocab.keys2idx(trg_ref_words, UNK_WORD,
                                            bos_word=BOS_WORD, eos_word=EOS_WORD)
                         for trg_ref_words in trg_refs_words])
            slens.append(src_len)
    finally:
        # Close every handle even if reading or a vocabulary lookup raises.
        srcF.close()
        for trgF in trgFs: trgF.close()

    train_size = len(srcs)
    assert train_size == idx - ignore - longer, 'Wrong .. '
    wlog('Sentence-pairs count: {}(total) - {}(ignore) - {}(longer) = {}'.format(
        idx, ignore, longer, idx - ignore - longer))

    if shuffle is True:
        wlog('Shuffling the whole dataset ... ', False)
        rand_idxs = tc.randperm(train_size).tolist()
        srcs = [srcs[k] for k in rand_idxs]
        trgs = [trgs[k] for k in rand_idxs]
        slens = [slens[k] for k in rand_idxs]

    final_srcs, final_trgs = srcs, trgs

    if sort_data is True:
        final_srcs, final_trgs = [], []

        if wargs.sort_k_batches == 0:
            wlog('Sorting the whole dataset by ascending order of source length ... ', False)
            # Sort the whole training data by ascending source length.
            _, sorted_idx = tc.sort(tc.IntTensor(slens))
            sorted_idx = sorted_idx.tolist()
            final_srcs = [srcs[k] for k in sorted_idx]
            final_trgs = [trgs[k] for k in sorted_idx]
        else:
            wlog('Sorting for each {} batches ... '.format(wargs.sort_k_batches), False)

            # Sort only within each chunk of k batches; chunk order is kept.
            k_batch = wargs.batch_size * wargs.sort_k_batches
            for start in range(0, train_size, k_batch):
                end = start + k_batch
                bsrcs, btrgs = srcs[start:end], trgs[start:end]
                _, sorted_idx = tc.sort(tc.IntTensor(slens[start:end]))
                sorted_idx = sorted_idx.tolist()
                final_srcs += [bsrcs[k] for k in sorted_idx]
                final_trgs += [btrgs[k] for k in sorted_idx]

    wlog('Done.')

    return final_srcs, final_trgs
示例#4
0
def tokenize_lower(txt, char=False):
    """Return the tokens of ``txt`` after stripping and lower-casing it.

    If ``char`` is exactly True, the normalized text is tokenized with
    ``zh_to_chars``. Otherwise it is split on whitespace.
    """
    normalized = txt.strip().lower()
    if char is not True:
        return normalized.split()
    return zh_to_chars(normalized)