Example #1
import os
from os import makedirs
from os.path import exists  # minimal imports so the bare exists/makedirs calls resolve


def main(unused_argv):
    del unused_argv  # Unused

    # `get_lm_corpus` is defined in the same module (see Example #3);
    # `FLAGS` holds the parsed command-line flags.
    corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)

    # TFRecord shards are written under <data_dir>/tfrecords.
    save_dir = os.path.join(FLAGS.data_dir, "tfrecords")
    if not exists(save_dir):
        makedirs(save_dir)

    # Test mode: convert only the test split, then stop.
    if FLAGS.per_host_test_bsz > 0:
        corpus.convert_to_tfrecords("test",
                                    save_dir,
                                    FLAGS.per_host_test_bsz,
                                    FLAGS.tgt_len,
                                    FLAGS=FLAGS)
        return

    # Convert the train and validation splits that were given a positive
    # batch size.
    for split, batch_size in zip(
        ["train", "valid"],
        [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):

        if batch_size <= 0:
            continue
        print("Converting {} set...".format(split))
        corpus.convert_to_tfrecords(split,
                                    save_dir,
                                    batch_size,
                                    FLAGS.tgt_len,
                                    FLAGS=FLAGS)
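For reference, here is a minimal sketch of how the flags this entry point reads could be declared with absl (the flag names come from the snippet above; the defaults and the choice of absl are assumptions):

from absl import app, flags

flags.DEFINE_string("data_dir", None, "Directory holding the raw dataset.")
flags.DEFINE_string("dataset", "wt103", "Dataset name, e.g. wt103 or enwik8.")
flags.DEFINE_integer("tgt_len", 128, "Number of tokens per sequence.")
flags.DEFINE_integer("per_host_train_bsz", 32, "Train batch size; <= 0 skips the split.")
flags.DEFINE_integer("per_host_valid_bsz", 32, "Valid batch size; <= 0 skips the split.")
flags.DEFINE_integer("per_host_test_bsz", 0, "Test batch size; > 0 switches to test mode.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
    app.run(main)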
Example #2
    def count_file(self, path, verbose=False, add_eos=False):
        """Tokenize a file line by line, updating self.counter with the symbols."""
        if verbose:
            print('counting file {} ...'.format(path))
        assert exists(path)

        sents = []
        with open(path, 'r') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('  line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents
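A short usage sketch, assuming count_file belongs to a Transformer-XL-style Vocab class that accumulates counts in a collections.Counter at self.counter (the class name and the build_vocab call are assumptions, not shown in the snippet):

vocab = Vocab(special=['<eos>'], lower_case=False)
train_sents = vocab.count_file('data/train.txt', verbose=True)
valid_sents = vocab.count_file('data/valid.txt')
vocab.build_vocab()  # hypothetical: freeze the symbol table from the counts
print('distinct symbols seen:', len(vocab.counter))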
Example #3
import json
import os
import pickle
from os.path import exists


def get_lm_corpus(data_dir, dataset):
    # A pickled Corpus is cached next to the raw data so that vocabulary
    # building and tokenization only happen once per dataset.
    fn = os.path.join(data_dir, "cache.pkl")

    if exists(fn):
        print("Loading cached dataset...")
        with open(fn, "rb") as fp:
            corpus = pickle.load(fp, encoding="latin1")
    else:
        print("Producing dataset...")
        kwargs = {}
        if dataset in ["wt103", "wt2", "sb2", "sb92"]:
            kwargs["special"] = ["<eos>"]
            kwargs["lower_case"] = False
        elif dataset == "wt103small":
            kwargs["special"] = ["<UNK>", "<eos>"]
            kwargs["lower_case"] = False
            kwargs["min_freq"] = 30
        elif dataset == "ptb":
            kwargs["special"] = ["<eos>"]
            kwargs["lower_case"] = True
        elif dataset == "lm1b":
            kwargs["special"] = []
            kwargs["lower_case"] = False
            kwargs["vocab_file"] = os.path.join(data_dir, "1b_word_vocab.txt")
        elif dataset in ["enwik8", "text8"]:
            pass

        corpus = Corpus(data_dir, dataset, **kwargs)

        print("Saving dataset...")
        with open(fn, "wb") as fp:
            pickle.dump(corpus, fp, protocol=2)

        # Summary metadata for downstream tools; note it is only written on a
        # cache miss, together with cache.pkl.
        corpus_info = {
            "vocab_size": len(corpus.vocab),
            "cutoffs": corpus.cutoffs,
            "dataset": corpus.dataset
        }
        with open(os.path.join(data_dir, "corpus-info.json"), "w") as fp:
            json.dump(corpus_info, fp)

    return corpus
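Typical usage, assuming a WikiText-103 layout under data_dir (the paths are illustrative):

corpus = get_lm_corpus("./data/wikitext-103", "wt103")
print("vocab size:", len(corpus.vocab))
corpus = get_lm_corpus("./data/wikitext-103", "wt103")  # second call loads cache.pkl

Since corpus-info.json is only written on a cache miss, deleting cache.pkl forces both files to be regenerated.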
Example #4
    def encode_file(self,
                    path,
                    ordered=False,
                    verbose=False,
                    add_double_eos=False):
        if verbose:
            print('encoding file {} ...'.format(path))
        assert exists(path)
        encoded = []
        with open(path, 'r') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('  line {}'.format(idx))
                symbols = self.tokenize(line,
                                        add_eos=True,
                                        add_double_eos=add_double_eos)

                encoded.append(self.convert_to_nparray(symbols))

        # For ordered (contiguous) corpora, flatten the per-line arrays into a
        # single token stream; otherwise return one array per line.
        if ordered:
            encoded = np.concatenate(encoded)

        return encoded
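A minimal calling sketch for this variant (vocab stands in for the enclosing object; the path is illustrative):

# ordered=True returns one flat numpy array of ids; ordered=False would
# return a list with one array per input line.
stream = vocab.encode_file('data/valid.txt', ordered=True, verbose=True)
print(stream.shape)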
Example #5
    def encode_file(self,
                    path,
                    ordered=False,
                    verbose=False,
                    add_eos=True,
                    add_double_eos=False,
                    ret_doc_boundary=False,
                    pattern=None):
        if verbose:
            print('encoding file {} ...'.format(path))
        assert exists(path)
        encoded = []
        doc_boundary = []
        with open(path, 'r') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('  line {}'.format(idx))
                if ret_doc_boundary:
                    symbols, db = self.tokenize(
                        line,
                        add_eos=add_eos,
                        add_double_eos=add_double_eos,
                        ret_doc_boundary=ret_doc_boundary,
                        pattern=pattern)
                    doc_boundary.append(np.array(db, dtype=bool))
                else:
                    symbols = self.tokenize(line,
                                            add_eos=add_eos,
                                            add_double_eos=add_double_eos)
                encoded.append(self.convert_to_nparray(symbols))

        # For ordered corpora, flatten the per-line arrays (and boundary
        # flags) into single arrays; otherwise keep them as per-line lists.
        if ordered:
            encoded = np.concatenate(encoded)
            if ret_doc_boundary:
                doc_boundary = np.concatenate(doc_boundary)

        if ret_doc_boundary:
            return encoded, doc_boundary
        return encoded
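And for this extended variant, a hedged sketch of requesting document boundaries (the pattern regex is illustrative; it is passed through to the enclosing tokenize, whose matching rules are not shown here):

tokens, doc_starts = vocab.encode_file(
    'data/train.txt',
    ordered=True,
    ret_doc_boundary=True,
    pattern=r'^= [^=]')
print(tokens.shape, doc_starts.shape)  # boolean flags returned alongside the ids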