def main(unused_argv):
  del unused_argv  # Unused

  corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)

  save_dir = os.path.join(FLAGS.data_dir, "tfrecords")
  if not exists(save_dir):
    makedirs(save_dir)

  # test mode
  if FLAGS.per_host_test_bsz > 0:
    corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz,
                                FLAGS.tgt_len, FLAGS=FLAGS)
    return

  for split, batch_size in zip(
      ["train", "valid"],
      [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):
    if batch_size <= 0:
      continue
    print("Converting {} set...".format(split))
    corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len,
                                FLAGS=FLAGS)
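# A minimal invocation sketch (not from the original repo). The module name
# data_utils.py and the flag values are illustrative assumptions; the flag
# names themselves all appear as FLAGS attributes in main() above.
#
#   python data_utils.py \
#     --data_dir=/path/to/dataset \
#     --dataset=wt103 \
#     --per_host_train_bsz=60 \
#     --per_host_valid_bsz=60 \
#     --per_host_test_bsz=0 \
#     --tgt_len=192
#
# Setting --per_host_test_bsz > 0 switches main() into test-only mode: it
# converts just the "test" split to TFRecords and returns.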
def count_file(self, path, verbose=False, add_eos=False):
  # Tokenize every line of `path`, update the running symbol counter, and
  # return the tokenized sentences so the caller can reuse them later.
  if verbose:
    print('counting file {} ...'.format(path))
  assert exists(path)

  sents = []
  with open(path, 'r') as f:
    for idx, line in enumerate(f):
      if verbose and idx > 0 and idx % 500000 == 0:
        print(' line {}'.format(idx))
      # Honor the add_eos argument instead of hard-coding add_eos=True,
      # which left the parameter unused.
      symbols = self.tokenize(line, add_eos=add_eos)
      self.counter.update(symbols)
      sents.append(symbols)

  return sents
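# A hedged usage sketch for count_file above. `vocab` stands for whatever
# object defines this method (its class and construction are not shown here),
# and self.counter is assumed to behave like collections.Counter.
#
#   sents = vocab.count_file('/path/to/train.txt', verbose=True, add_eos=True)
#   print('lines counted:', len(sents))
#   print('most frequent symbols:', vocab.counter.most_common(5))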
def get_lm_corpus(data_dir, dataset):
  fn = os.path.join(data_dir, "cache.pkl")

  if exists(fn):
    print("Loading cached dataset...")
    with open(fn, "rb") as fp:
      corpus = pickle.load(fp, encoding="latin1")
  else:
    print("Producing dataset...")
    kwargs = {}
    if dataset in ["wt103", "wt2", "sb2"]:
      kwargs["special"] = ["<eos>"]
      kwargs["lower_case"] = False
    elif dataset == "sb92":
      kwargs["special"] = ["<eos>"]
      kwargs["lower_case"] = False
    elif dataset == "wt103small":
      kwargs["special"] = ["<UNK>", "<eos>"]
      kwargs["lower_case"] = False
      kwargs["min_freq"] = 30
    elif dataset == "ptb":
      kwargs["special"] = ["<eos>"]
      kwargs["lower_case"] = True
    elif dataset == "lm1b":
      kwargs["special"] = []
      kwargs["lower_case"] = False
      kwargs["vocab_file"] = os.path.join(data_dir, "1b_word_vocab.txt")
    elif dataset in ["enwik8", "text8"]:
      pass

    corpus = Corpus(data_dir, dataset, **kwargs)

    print("Saving dataset...")
    with open(fn, "wb") as fp:
      pickle.dump(corpus, fp, protocol=2)

    corpus_info = {
        "vocab_size": len(corpus.vocab),
        "cutoffs": corpus.cutoffs,
        "dataset": corpus.dataset
    }
    with open(os.path.join(data_dir, "corpus-info.json"), "w") as fp:
      json.dump(corpus_info, fp)

  return corpus
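# A small usage sketch, not part of the original pipeline: build (or load) the
# cached corpus via get_lm_corpus above, then read back corpus-info.json. The
# helper's name and the idea of re-reading the JSON file are assumptions for
# illustration; everything else follows directly from the code above.
def _example_load_corpus(data_dir, dataset="wt103"):
  corpus = get_lm_corpus(data_dir, dataset)

  # corpus-info.json is written the first time the corpus is produced; it lets
  # other jobs pick up vocab_size/cutoffs without unpickling the full corpus.
  with open(os.path.join(data_dir, "corpus-info.json")) as fp:
    corpus_info = json.load(fp)

  print("vocab_size={}, cutoffs={}".format(
      corpus_info["vocab_size"], corpus_info["cutoffs"]))
  return corpus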
def encode_file(self, path, ordered=False, verbose=False,
                add_double_eos=False):
  if verbose:
    print('encoding file {} ...'.format(path))
  assert exists(path)

  encoded = []
  with open(path, 'r') as f:
    for idx, line in enumerate(f):
      if verbose and idx > 0 and idx % 500000 == 0:
        print(' line {}'.format(idx))
      symbols = self.tokenize(line, add_eos=True,
                              add_double_eos=add_double_eos)
      encoded.append(self.convert_to_nparray(symbols))

  if ordered:
    encoded = np.concatenate(encoded)

  return encoded
def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
                add_double_eos=False, ret_doc_boundary=False, pattern=None):
  if verbose:
    print('encoding file {} ...'.format(path))
  assert exists(path)

  encoded = []
  doc_boundary = []
  with open(path, 'r') as f:
    for idx, line in enumerate(f):
      if verbose and idx > 0 and idx % 500000 == 0:
        print(' line {}'.format(idx))
      if ret_doc_boundary:
        symbols, db = self.tokenize(
            line, add_eos=add_eos, add_double_eos=add_double_eos,
            ret_doc_boundary=ret_doc_boundary, pattern=pattern)
        doc_boundary.append(np.array(db, dtype=bool))
      else:
        symbols = self.tokenize(line, add_eos=add_eos,
                                add_double_eos=add_double_eos)
      encoded.append(self.convert_to_nparray(symbols))

  if ordered:
    encoded = np.concatenate(encoded)

  if ret_doc_boundary:
    doc_boundary = np.concatenate(doc_boundary)
    return encoded, doc_boundary

  return encoded
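# A hedged usage sketch for the document-boundary variant above. `vocab` is a
# placeholder for the object that defines encode_file, and the regex is
# illustrative; what `pattern` must match depends on tokenize(), which is not
# shown here. Only the keyword arguments come from the signature above.
#
#   doc_pattern = r'^<doc'  # hypothetical document-start marker
#   data, doc_boundary = vocab.encode_file(
#       '/path/to/train.txt', ordered=True, verbose=True,
#       ret_doc_boundary=True, pattern=doc_pattern)
#   print('total tokens:', len(data))
#   print('document starts flagged:', int(doc_boundary.sum()))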