def count_vocab(data_file, max_vcb_size, max_seq_len=50, char=False):

    assert data_file and os.path.exists(data_file), 'need file to extract vocabulary ...'

    vocab = Vocab()
    with io.open(data_file, encoding='utf-8') as f:
        for sent in f.readlines():
            sent = sent.strip()
            # split into characters (Chinese) or whitespace-delimited tokens
            if char is True:
                words = zh_to_chars(sent)
            else:
                words = sent.split()
            # skip sentences longer than max_seq_len
            if len(words) > max_seq_len:
                continue
            for word in words:
                vocab.add(word)

    words_cnt = sum(vocab.freq.values())
    new_vocab, new_words_cnt = vocab.keep_vocab_size(max_vcb_size)
    wlog('|Final vocabulary| / |Original vocabulary| = {} / {} = {:4.2f}%'.format(
        new_words_cnt, words_cnt, (new_words_cnt / float(words_cnt)) * 100))

    new_vocab.idx2key = {k: str(v) for k, v in new_vocab.idx2key.items()}
    new_vocab.key2idx = {str(k): v for k, v in new_vocab.key2idx.items()}

    return new_vocab
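# Example usage (a minimal sketch; the file name 'train.zh' and the 30k size
# are hypothetical, Vocab/wlog/zh_to_chars are assumed to come from this repo):
#
#   src_vocab = count_vocab('train.zh', max_vcb_size=30000, max_seq_len=50, char=True)
#   wlog('Kept {} types in the source vocabulary'.format(len(src_vocab.idx2key)))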
def wrap_tst_data(src_data, src_vocab, char=False):

    srcs, slens = [], []
    srcF = io.open(src_data, mode='r', encoding='utf-8')
    idx = 0

    while True:

        src_sent = srcF.readline()
        # readline() returns '' only at end of file
        if src_sent == '':
            wlog('\nFinish to read monolingual test dataset {}, count {}'.format(src_data, idx))
            break
        idx += 1

        # check for blank lines after stripping, otherwise they slip through
        src_sent = src_sent.strip()
        if src_sent == '':
            wlog('Error. Ignore abnormal blank sentence in line number {}'.format(idx))
            sys.exit(0)

        if char is True:
            src_sent = ' '.join(zh_to_chars(src_sent))
        src_words = src_sent.split()
        src_len = len(src_words)

        srcs.append(src_vocab.keys2idx(src_words, UNK_WORD))
        slens.append(src_len)

    srcF.close()
    return srcs, slens
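# Example usage (a minimal sketch; 'test.zh' is a hypothetical path and src_vocab
# is a Vocab built by count_vocab above):
#
#   test_srcs, test_lens = wrap_tst_data('test.zh', src_vocab, char=False)
#   # test_srcs[i] holds the word indices of sentence i, test_lens[i] its length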
def wrap_data(data_dir, file_prefix, src_suffix, trg_prefix, src_vocab, trg_vocab,
              shuffle=True, sort_data=True, max_seq_len=50, char=False):

    # first pass: count lines so we can print progress markers
    srcF = open(os.path.join(data_dir, '{}.{}'.format(file_prefix, src_suffix)), 'r')
    num = len(srcF.readlines())
    srcF.close()
    point_every, number_every = int(math.ceil(num / 100)), int(math.ceil(num / 10))

    srcF = io.open(os.path.join(data_dir, '{}.{}'.format(file_prefix, src_suffix)),
                   mode='r', encoding='utf-8')

    trgFs = []  # maybe have multi-references for valid, we open them together
    for fname in os.listdir(data_dir):
        if fname.startswith('{}.{}'.format(file_prefix, trg_prefix)):
            wlog('\t{}'.format(os.path.join(data_dir, fname)))
            trgFs.append(io.open(os.path.join(data_dir, fname), mode='r', encoding='utf-8'))
    wlog('NOTE: Target side has {} references.'.format(len(trgFs)))

    idx, ignore, longer = 0, 0, 0
    srcs, trgs, slens = [], [], []

    while True:

        src_sent = srcF.readline().strip()
        if char is True:
            src_sent = ' '.join(zh_to_chars(src_sent))
        trg_refs = [trgF.readline().strip() for trgF in trgFs]

        # all files exhausted at the same time -> normal end of corpus
        if src_sent == '' and all([trg_ref == '' for trg_ref in trg_refs]):
            wlog('\nFinish to read bilingual corpus.')
            break

        # progress indicator
        if numpy.mod(idx + 1, point_every) == 0: wlog('.', False)
        if numpy.mod(idx + 1, number_every) == 0: wlog('{}'.format(idx + 1), False)
        idx += 1

        # blank line on either side -> skip this sentence pair
        if src_sent == '' or any([trg_ref == '' for trg_ref in trg_refs]):
            wlog('Ignore abnormal blank sentence in line number {}'.format(idx))
            ignore += 1
            continue

        src_words = src_sent.split()
        src_len = len(src_words)
        trg_refs_words = [trg_ref.split() for trg_ref in trg_refs]

        # keep the pair only if both sides fit into max_seq_len
        if src_len <= max_seq_len and all([len(tws) <= max_seq_len for tws in trg_refs_words]):
            src_tensor = src_vocab.keys2idx(src_words, UNK_WORD)
            trg_refs_tensor = [trg_vocab.keys2idx(trg_ref_words, UNK_WORD,
                                                  bos_word=BOS_WORD, eos_word=EOS_WORD)
                               for trg_ref_words in trg_refs_words]
            srcs.append(src_tensor)
            trgs.append(trg_refs_tensor)
            slens.append(src_len)
        else:
            longer += 1

    srcF.close()
    for trgF in trgFs: trgF.close()

    train_size = len(srcs)
    assert train_size == idx - ignore - longer, 'Wrong .. '
    wlog('Sentence-pairs count: {}(total) - {}(ignore) - {}(longer) = {}'.format(
        idx, ignore, longer, idx - ignore - longer))

    if shuffle is True:
        #assert len(trgFs) == 1, 'Unsupport to shuffle validation set.'
        wlog('Shuffling the whole dataset ... ', False)
        rand_idxs = tc.randperm(train_size).tolist()
        srcs = [srcs[k] for k in rand_idxs]
        trgs = [trgs[k] for k in rand_idxs]
        slens = [slens[k] for k in rand_idxs]

    final_srcs, final_trgs = srcs, trgs

    if sort_data is True:
        #assert len(trgFs) == 1, 'Unsupport to sort validation set in k batches.'
        final_srcs, final_trgs = [], []
        if wargs.sort_k_batches == 0:
            wlog('Sorting the whole dataset by ascending order of source length ... ', False)
            # sort the whole training data by ascending order of source length
            _, sorted_idx = tc.sort(tc.IntTensor(slens))
            final_srcs = [srcs[k] for k in sorted_idx.tolist()]
            final_trgs = [trgs[k] for k in sorted_idx.tolist()]
        else:
            wlog('Sorting for each {} batches ... '.format(wargs.sort_k_batches), False)
            k_batch = wargs.batch_size * wargs.sort_k_batches
            number = int(math.ceil(train_size / k_batch))
            # sort by source length inside each chunk of sort_k_batches batches
            for start in range(number):
                bsrcs = srcs[start * k_batch: (start + 1) * k_batch]
                btrgs = trgs[start * k_batch: (start + 1) * k_batch]
                bslens = slens[start * k_batch: (start + 1) * k_batch]
                _, sorted_idx = tc.sort(tc.IntTensor(bslens))
                final_srcs += [bsrcs[k] for k in sorted_idx.tolist()]
                final_trgs += [btrgs[k] for k in sorted_idx.tolist()]

    wlog('Done.')
    return final_srcs, final_trgs
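# Example usage (a minimal sketch; the layout 'data/train.src' / 'data/train.trg*'
# is hypothetical, and wargs.batch_size / wargs.sort_k_batches are assumed to be
# set in the global config module):
#
#   train_srcs, train_trgs = wrap_data('data', 'train', 'src', 'trg',
#                                      src_vocab, trg_vocab,
#                                      shuffle=True, sort_data=True, max_seq_len=50)
#   # train_srcs[i] is one source sentence as indices; train_trgs[i] is a list with
#   # one index sequence per reference, each wrapped with BOS/EOS by keys2idx.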
def tokenize_lower(txt, char=False):
    txt = txt.strip().lower()
    if char is True:
        txt = zh_to_chars(txt)
    else:
        txt = txt.split()
    return txt
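# Example (a minimal sketch; the inputs are made up):
#
#   tokenize_lower('Hello World')          # -> ['hello', 'world']
#   tokenize_lower(u'你好世界', char=True)  # -> list of single characters from zh_to_chars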