Example #1
0
class Corpus(object):
    """Vocabulary plus ordered train/valid/test encodings for one dataset.

    The vocabulary is counted from ``train.txt`` only; ``valid.txt`` and
    ``test.txt`` are encoded against it. Cutoffs are placed at 10%, 20% and
    40% of the vocabulary size, bracketed by 0 and the full size.
    """

    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        train_file = os.path.join(path, "train.txt")
        self.vocab.count_file(train_file)
        self.vocab.build_vocab()

        def _encode(file_name):
            # Every split is encoded in order against the training vocab.
            return self.vocab.encode_file(os.path.join(path, file_name), ordered=True)

        self.train = _encode("train.txt")
        self.valid = _encode("valid.txt")
        self.test = _encode("test.txt")

        size = len(self.vocab)
        self.cutoffs = [0] + [int(size * frac) for frac in (0.1, 0.2, 0.4)] + [size]

    def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len, num_core_per_host, **kwargs):
        """Serialize ``split`` to TFRecords and write a record-info JSON.

        ``num_core_per_host`` and ``kwargs`` are accepted for interface
        compatibility but are not used by this variant.
        """
        info_name = "record_info-{}.bsz-{}.tlen-{}.json".format(split, bsz, tgt_len)
        info_path = os.path.join(save_dir, info_name)

        split_data = getattr(self, split)
        tfrecord_name, batch_count = create_ordered_tfrecords(
            save_dir, split, split_data, bsz, tgt_len)

        with open(info_path, "w") as fp:
            json.dump({"filenames": [tfrecord_name], "num_batch": batch_count}, fp)
class Corpus(object):
    """Vocabulary plus ordered train/valid encodings for one dataset.

    Symbols are counted over both ``train.txt`` and ``valid.txt``. Only the
    ``train`` and ``valid`` attributes are populated: this variant sets no
    ``test`` attribute and no ``cutoffs`` attribute.
    """

    def __init__(self, path, dataset, *args, **kwargs):

        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        train_path = os.path.join(path, "train.txt")
        valid_path = os.path.join(path, "valid.txt")

        # Update the vocab's counter (occurrences of each distinct symbol),
        # first from the training split, then from the validation split.
        for split_path in (train_path, valid_path):
            self.vocab.count_file(split_path)

        # Build idx2sym / sym2idx so symbols map to indices and back.
        self.vocab.build_vocab()

        self.train, self.valid = (
            self.vocab.encode_file(split_path, ordered=True)
            for split_path in (train_path, valid_path)
        )

        # self.cutoffs is intentionally left unset. Original author's note
        # (translated): without a TPU, all cutoff-related code is redundant.

    def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len, **kwargs):
        """Serialize ``split`` to TFRecords and write a record-info JSON.

        The JSON always carries ``"bin_sizes": None`` because this variant
        does not compute bin sizes.
        """
        record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(split, bsz, tgt_len)
        record_info_path = os.path.join(save_dir, record_name)

        file_name, num_batch = create_ordered_tfrecords(
            save_dir, split, getattr(self, split), bsz, tgt_len)

        with open(record_info_path, "w") as fp:
            json.dump({"filenames": [file_name],
                       "bin_sizes": None,
                       "num_batch": num_batch}, fp)
def inference(n_token, cutoffs, ps_device, dataset_name="doupo", output_len=200):
    """Interactively generate text from a restored checkpoint.

    Repeatedly reads a seed prompt from stdin. For each prompt it generates
    ``output_len`` tokens one at a time. After every step, the fetched
    per-layer memories (and their id memories) are fed back as the next
    step's memories. Memories are initialised to zeros once and are NOT reset
    between prompts. Each prompt and its continuation are appended to
    ``<dataset_name>.txt``. The loop ends when stdin reaches end-of-file.

    Args:
        n_token: Ignored. The vocabulary size is recomputed from
            ``../data/<dataset_name>/train.txt`` so that it matches the
            vocabulary used here to encode prompts and decode outputs.
        cutoffs: Cutoffs forwarded to ``single_core_graph_for_inference``.
        ps_device: Device passed to ``assign_to_gpu`` for each core.
        dataset_name: Sub-directory of ``../data`` that holds ``train.txt``.
            It is also the stem of the log file. Defaults to ``"doupo"``.
        output_len: Number of tokens generated per prompt (default 200).
    """
    tmp_Vocab = Vocab()
    tmp_Vocab.count_file("../data/{}/train.txt".format(dataset_name), add_eos=False)
    tmp_Vocab.build_vocab()

    # Override the caller's value with the size of the vocabulary built above.
    n_token = len(tmp_Vocab)

    # One prompt of variable length (batch size 1) per iterator initialisation.
    test_list = tf.placeholder(tf.int64, shape=[1, None])
    dataset = tf.data.Dataset.from_tensors(test_list)
    iterator = dataset.make_initializable_iterator()
    input_feed = iterator.get_next()

    # NOTE(review): the batch dimension is 1, so this split only works when
    # FLAGS.num_core_per_host == 1.
    inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)

    per_core_bsz = 1
    tower_mems, tower_new_mems = [], []
    tower_output = []
    tower_mems_id = []
    tower_new_mems_id = []
    tower_attn_prob = []

    for i in range(FLAGS.num_core_per_host):
        with tf.device(assign_to_gpu(i, ps_device)), \
             tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            # One memory placeholder per layer, plus a parallel int64 "id"
            # memory (presumably the token ids behind each memory slot).
            mems_i = [tf.placeholder(tf.float32,
                                     [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                      for _ in range(FLAGS.n_layer)]

            mems_i_id = [tf.placeholder(tf.int64,
                                        [FLAGS.mem_len, per_core_bsz])
                         for _ in range(FLAGS.n_layer)]

            new_mems_i, output_i, new_mems_i_id, attn_prob_i = single_core_graph_for_inference(
                n_token=n_token,
                cutoffs=cutoffs,
                is_training=False,
                inp=inputs[i],
                mems=mems_i,
                mems_id=mems_i_id)

            tower_mems.append(mems_i)
            tower_new_mems.append(new_mems_i)
            tower_output.append(output_i)
            tower_mems_id.append(mems_i_id)
            tower_new_mems_id.append(new_mems_i_id)
            tower_attn_prob.append(attn_prob_i)

    # Initial all-zero memories. The id memories are int64 to match the
    # tf.int64 placeholders they are fed into.
    tower_mems_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
         for _ in range(FLAGS.n_layer)]
        for _ in range(FLAGS.num_core_per_host)
    ]

    tower_mems_id_np = [
        [np.zeros([FLAGS.mem_len, per_core_bsz], dtype=np.int64)
         for _ in range(FLAGS.n_layer)]
        for _ in range(FLAGS.num_core_per_host)
    ]

    saver = tf.train.Saver()
    log_path = '{}.txt'.format(dataset_name)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.eval_ckpt_path is None:
            eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        else:
            eval_ckpt_path = FLAGS.eval_ckpt_path
        print('eval_ckpt_path:', eval_ckpt_path)
        saver.restore(sess, eval_ckpt_path)

        fetches = [tower_new_mems,
                   tower_output,
                   tower_new_mems_id,
                   tower_attn_prob,
                   'transformer/adaptive_embed/lookup_table:0']

        while True:
            try:
                input_text = input("seed text >>> ")
                while not input_text:
                    print('Prompt should not be empty!')
                    input_text = input("seed text >>> ")
            except EOFError:
                # stdin closed: leave the interactive loop cleanly.
                break
            encoded_input = tmp_Vocab.encode_sents(input_text, ordered=True)

            with open(log_path, 'a', encoding='utf-8') as f:
                f.write('-' * 100 + '\n')
                f.write('input:\n')
                f.write(input_text + '\n')

            progress = ProgressBar()
            for _ in progress(range(output_len)):
                time.sleep(0.01)
                feed_dict = {}
                for i in range(FLAGS.num_core_per_host):
                    for mem_ph, mem_np in zip(tower_mems[i], tower_mems_np[i]):
                        feed_dict[mem_ph] = mem_np

                    for mem_id_ph, mem_id_np in zip(tower_mems_id[i], tower_mems_id_np[i]):
                        feed_dict[mem_id_ph] = mem_id_np

                # Re-initialise the iterator with the tokens to feed this
                # step: the whole prompt first, then only the last token.
                sess.run(iterator.initializer, feed_dict={test_list: [encoded_input]})
                fetched = sess.run(fetches, feed_dict=feed_dict)

                tower_mems_np, output = fetched[:2]
                tower_mems_id_np = fetched[2]
                attn_prob = fetched[3]      # only used by the visualisations below
                lookup_table = fetched[4]   # only used by gen_on_keyword below

                # Scores at the last position of core 0, batch element 0
                # (output presumably shaped (len, bsz, n_token) — confirm).
                tmp_list = output[0][-1][0].tolist()

                # Six optional post-processing steps follow; keep the ones
                # you need and comment out the others.
                # Next-token choice: top-1
                index = top_one_result(tmp_list)
                # Next-token choice: diversity
                # index = gen_diversity(tmp_list)
                # Next-token choice: based on a keyword
                # index = gen_on_keyword(tmp_Vocab, '喜', tmp_list, lookup_table)

                # Visualise candidate words
                # visualize_prob(tmp_Vocab, tmp_list,
                # '../exp_result/{}/candidates'.format(dataset_name+'mem_len500'), len(input_text))

                # Visualise attention per layer
                # visualize_attention_per_layer(tmp_Vocab, tower_mems_id_np, attn_prob, index,
                #                               '../exp_result/{}/attention_per_layer'.format(dataset_name+'mem_len500'),
                #                               len(input_text))

                # Visualise attention per head
                # visualize_attention_per_head(tmp_Vocab, tower_mems_id_np, attn_prob, index,
                #                              '../exp_result/{}/attention_per_head'.format(dataset_name+'_repeat'),
                #                              len(input_text))

                symbol = tmp_Vocab.get_sym(index)
                input_text += symbol if symbol != '<eos>' else '\n'
                encoded_input = [index]

            print(input_text)

            with open(log_path, 'a', encoding='utf-8') as f:
                f.write('output:\n')
                f.write(input_text + '\n')
                f.write('-' * 100 + '\n')
Example #4
0
class Corpus(object):
    """Builds a vocabulary and encodes the splits of a named dataset.

    Supported ``dataset`` values: "ptb", "wt2", "wt103", "enwik8", "text8",
    "lm1b" and "generic_dataset".

    For "generic_dataset":
      * ``add_eos`` and ``add_double_eos`` are removed from ``kwargs`` and
        used when encoding;
      * a ``vocab_file`` is resolved relative to ``path``;
      * cutoffs are read from ``<path>/cutoffs.json``.
    """

    # Datasets whose splits are fully encoded in memory and written in one pass.
    _ORDERED_DATASETS = ("ptb", "wt2", "wt103", "enwik8", "text8", "generic_dataset")

    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        if self.dataset == "generic_dataset":
            # Encoding options are popped so they are not forwarded to Vocab.
            encode_kwargs = dict(
                add_eos=kwargs.pop('add_eos', False),
                add_double_eos=kwargs.pop('add_double_eos', False),
                ordered=True,
                verbose=True,
            )
            if kwargs.get('vocab_file') is not None:
                kwargs['vocab_file'] = os.path.join(path, kwargs['vocab_file'])

        print(self.dataset, 'vocab params', kwargs)
        self.vocab = Vocab(*args, **kwargs)

        split_files = [os.path.join(path, name)
                       for name in ("train.txt", "valid.txt", "test.txt")]

        # ---- count symbols --------------------------------------------------
        if (self.dataset in ["ptb", "wt2", "enwik8", "text8"]
                or (self.dataset == "generic_dataset"
                    and not self.vocab.vocab_file)):
            for split_file in split_files:
                self.vocab.count_file(split_file)
        elif self.dataset == "wt103":
            self.vocab.count_file(split_files[0])
        elif self.dataset == "lm1b":
            train_path_pattern = os.path.join(
                path, "1-billion-word-language-modeling-benchmark-r13output",
                "training-monolingual.tokenized.shuffled", "news.en-*")
            train_paths = glob(train_path_pattern)
            # Nothing is counted here: the vocab will load from file when
            # build_vocab() is called.

        self.vocab.build_vocab()

        # ---- encode splits --------------------------------------------------
        if self.dataset in ["ptb", "wt2", "wt103"]:
            self.train, self.valid, self.test = [
                self.vocab.encode_file(f, ordered=True) for f in split_files]
        elif self.dataset == "generic_dataset":
            self.train, self.valid, self.test = [
                self.vocab.encode_file(f, **encode_kwargs) for f in split_files]
        elif self.dataset in ["enwik8", "text8"]:
            self.train, self.valid, self.test = [
                self.vocab.encode_file(f, ordered=True, add_eos=False)
                for f in split_files]
        elif self.dataset == "lm1b":
            # Training shards stay as file paths and are encoded shard by
            # shard in convert_to_tfrecords; valid.txt also serves as test.
            self.train = train_paths
            valid_path = os.path.join(path, "valid.txt")
            self.valid = self.vocab.encode_file(valid_path,
                                                ordered=True,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(valid_path,
                                               ordered=True,
                                               add_double_eos=True)

        # ---- cutoffs --------------------------------------------------------
        if self.dataset == "wt103":
            self.cutoffs = [0, 20000, 40000, 200000] + [len(self.vocab)]
        elif self.dataset == "generic_dataset":
            with open(os.path.join(path, "cutoffs.json")) as f:
                self.cutoffs = json.load(f)
        elif self.dataset == "lm1b":
            self.cutoffs = [0, 60000, 100000, 640000] + [len(self.vocab)]
        else:
            self.cutoffs = []

    def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len,
                             num_core_per_host, **kwargs):
        """Write ``split`` as TFRecord file(s) plus a record-info JSON.

        ``kwargs['FLAGS']`` must be provided. This method reads
        ``use_tpu`` and ``num_passes``; for lm1b training it also reads
        ``num_procs`` and ``num_shuffle``.

        Raises:
            ValueError: if ``self.dataset`` is not supported. The check runs
                before any file is opened, so no empty record-info JSON is
                left behind.
        """
        if self.dataset not in self._ORDERED_DATASETS and self.dataset != "lm1b":
            raise ValueError("Unsupported dataset: {}".format(self.dataset))

        FLAGS = kwargs.get('FLAGS')

        file_names = []
        # A single-core test split is always written in the non-TPU layout.
        use_tpu = FLAGS.use_tpu and not (split == "test"
                                         and num_core_per_host == 1)

        if use_tpu:
            record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
                split, bsz, tgt_len, num_core_per_host)
        else:
            record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
                split, bsz, tgt_len)

        record_info_path = os.path.join(save_dir, record_name)

        if self.dataset in self._ORDERED_DATASETS:
            data = getattr(self, split)
            bin_sizes = get_bin_sizes(data, bsz // num_core_per_host, tgt_len,
                                      self.cutoffs)
            file_name, num_batch = create_ordered_tfrecords(
                save_dir,
                split,
                data,
                bsz,
                tgt_len,
                num_core_per_host,
                self.cutoffs,
                bin_sizes,
                num_passes=FLAGS.num_passes
                if split == 'train' and use_tpu else 1,
                use_tpu=use_tpu)
            file_names.append(file_name)
        else:  # lm1b
            bin_sizes = get_bin_sizes(self.valid, bsz // num_core_per_host,
                                      tgt_len, self.cutoffs)
            if split == "train":
                np.random.seed(123456)
                num_batch = 0

                if FLAGS.num_procs > 1:
                    _preprocess_wrapper = partial(
                        _preprocess,
                        train=self.train,
                        vocab=self.vocab,
                        save_dir=save_dir,
                        cutoffs=self.cutoffs,
                        bin_sizes=bin_sizes,
                        bsz=bsz,
                        tgt_len=tgt_len,
                        num_core_per_host=num_core_per_host,
                        use_tpu=use_tpu,
                        num_shuffle=FLAGS.num_shuffle)

                    # Use the pool as a context manager so its worker
                    # processes are released once map() has returned.
                    with mp.Pool(processes=FLAGS.num_procs) as pool:
                        results = pool.map(_preprocess_wrapper,
                                           range(len(self.train)))
                    for res in results:
                        file_names.extend(res[0])
                        num_batch += res[1]
                else:
                    for shard, path in enumerate(self.train):
                        data_shard = self.vocab.encode_file(
                            path, ordered=False, add_double_eos=True)

                        num_shuffle = FLAGS.num_shuffle

                        for shuffle in range(num_shuffle):
                            print("Processing shard {} shuffle {}".format(
                                shard, shuffle))
                            basename = "train-{:03d}-{:02d}".format(
                                shard, shuffle)
                            np.random.shuffle(data_shard)
                            file_name, num_batch_ = create_ordered_tfrecords(
                                save_dir,
                                basename,
                                np.concatenate(data_shard),
                                bsz,
                                tgt_len,
                                num_core_per_host,
                                self.cutoffs,
                                bin_sizes,
                                use_tpu=use_tpu)
                            file_names.append(file_name)
                            num_batch += num_batch_

            else:
                file_name, num_batch = create_ordered_tfrecords(
                    save_dir,
                    split,
                    getattr(self, split),
                    bsz,
                    tgt_len,
                    num_core_per_host,
                    self.cutoffs,
                    bin_sizes,
                    use_tpu=use_tpu)
                file_names.append(file_name)

        with open(record_info_path, "w") as fp:
            record_info = {
                "filenames": file_names,
                "bin_sizes": bin_sizes,
                "num_batch": num_batch
            }
            json.dump(record_info, fp)
Example #5
0
class Corpus(object):
    """Corpus for a dataset stored as train.txt / valid.txt / test.txt.

    When the Vocab has no ``vocab_file``, symbols are counted from
    ``train.txt``. Every split is encoded with ``add_eos=True`` and
    ``ordered=True``. No cutoffs are used (``self.cutoffs == []``).
    """

    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)
        # Only count symbols when the vocabulary is not loaded from a file.
        if self.vocab.vocab_file is None:
            self.vocab.count_file(os.path.join(path, "train.txt"))
        self.vocab.build_vocab()

        self.train = self.vocab.encode_file(os.path.join(path, "train.txt"),
                                            add_eos=True,
                                            ordered=True)
        self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"),
                                            add_eos=True,
                                            ordered=True)
        self.test = self.vocab.encode_file(os.path.join(path, "test.txt"),
                                           add_eos=True,
                                           ordered=True)

        self.cutoffs = []

    def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len,
                             num_core_per_host, **kwargs):
        """Write ``split`` as TFRecords plus a record-info JSON in ``save_dir``.

        ``kwargs['FLAGS']`` must be provided. This method reads ``use_tpu``
        and ``num_passes``.
        """
        FLAGS = kwargs.get('FLAGS')

        file_names = []
        # A single-core test split is always written in the non-TPU layout.
        use_tpu = FLAGS.use_tpu and not (split == "test"
                                         and num_core_per_host == 1)

        if use_tpu:
            record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
                split, bsz, tgt_len, num_core_per_host)
        else:
            record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
                split, bsz, tgt_len)

        record_info_path = os.path.join(save_dir, record_name)

        data = getattr(self, split)

        bin_sizes = get_bin_sizes(data, bsz // num_core_per_host, tgt_len,
                                  self.cutoffs)
        file_name, num_batch = create_ordered_tfrecords(
            save_dir,
            split,
            data,
            bsz,
            tgt_len,
            num_core_per_host,
            self.cutoffs,
            bin_sizes,
            num_passes=FLAGS.num_passes if split == 'train' and use_tpu else 1,
            use_tpu=use_tpu)
        file_names.append(file_name)

        with open(record_info_path, "w") as fp:
            record_info = {
                "filenames": file_names,
                "bin_sizes": bin_sizes,
                "num_batch": num_batch
            }
            json.dump(record_info, fp)
def main(unused_argv):
    """Score sentence perplexities or generate text, depending on FLAGS.

    With ``FLAGS.do_sent_ppl_pred``:
      * reads the CSV at ``FLAGS.input_file_dir`` (row 0 is the header, and
        column 0 of every other non-blank row is a sentence);
      * scores each sentence with ``sent_ppl``, split across
        ``FLAGS.multiprocess`` worker processes;
      * writes the CSV plus a "ppl" column to ``FLAGS.output_file_dir``;
      * writes a debug dump to ``non_batch_ref_output.txt``.

    With ``FLAGS.do_sent_gen``: runs ``sent_gen`` on each non-blank line of
    ``FLAGS.input_txt_dir`` and writes the results to ``sent_generation.txt``.
    """
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get corpus info. Only the interior cutoffs (without the first and last
    # entries) are passed on.
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    tf.logging.info("n_token {}".format(n_token))

    # Rebuild the vocabulary from the dataset's train.txt to encode inputs.
    tmp_Vocab = Vocab(special=["<bos>", "<eos>", "<UNK>"])
    tmp_Vocab.count_file("../data/{}/train.txt".format(FLAGS.dataset),
                         add_eos=False)
    tmp_Vocab.build_vocab()

    if FLAGS.do_sent_ppl_pred:
        encoded_txt_input = []
        txt_input = []
        input_csv = []
        # newline="" is required by the csv module so that quoted fields
        # containing line breaks are parsed correctly.
        with open(FLAGS.input_file_dir, "r", newline="") as read_file:
            for line in csv.reader(read_file):
                # Skip blank rows and rows whose first cell is empty. (A
                # `line[0].strip() != 0` test compares str with int and is
                # always true, so it never filtered anything.)
                if line and line[0].strip():
                    input_csv.append(line)

        for row in input_csv[1:]:
            sentence = row[0].strip()
            txt_input.append(sentence)
            encoded_txt_input.append(list(tmp_Vocab.encode_sents(
                sentence, add_eos=True, ordered=True)))

        # Truncate every sentence to at most limit_len tokens.
        encoded_txt_input = [line[:FLAGS.limit_len] for line in encoded_txt_input]

        # Sentences have different lengths, so build a 1-D object array
        # explicitly: np.array() on ragged nested lists raises ValueError on
        # NumPy >= 1.24.
        encoded_arr = np.empty(len(encoded_txt_input), dtype=object)
        for idx, seq in enumerate(encoded_txt_input):
            encoded_arr[idx] = seq
        encoded_txt_input = encoded_arr

        input_csv[0].append("ppl")

        pool = multiprocessing.Pool(FLAGS.multiprocess)

        parti_len = len(encoded_txt_input) // FLAGS.multiprocess
        pro_res_l = []

        for i in range(FLAGS.multiprocess):
            print("Setting process-%s" % i)
            # TODO: choose the GPU per process (move on to the next gpu:xx
            # when gpu:1 is full); every worker currently uses "/gpu:1".

            # The last worker also takes the remainder of the split.
            if i + 1 == FLAGS.multiprocess:
                end = len(encoded_txt_input)
            else:
                end = (i + 1) * parti_len
            pro_res_l.append(pool.apply_async(
                sent_ppl,
                args=(encoded_txt_input[i * parti_len:end], n_token, cutoffs, "/gpu:1")))

        # Collect in submission order so results line up with the input rows.
        res_l = []
        for pro_res in pro_res_l:
            res_l.extend(pro_res.get())

        pool.close()
        pool.join()
        print('All subprocesses done.')

        tf.logging.info('#time: {}'.format(time.time()))

        for i in range(1, len(input_csv)):
            input_csv[i].append(res_l[i - 1])
        output_df = pd.DataFrame(input_csv[1:], columns=input_csv[0])
        output_df.to_csv(FLAGS.output_file_dir,
                         sep=",",
                         index=False,
                         encoding="utf-8-sig")

        with open("non_batch_ref_output.txt", "w", encoding="utf-8") as write_res:
            for i in range(len(txt_input)):
                write_res.write(txt_input[i] + " " +
                                str(encoded_txt_input[i]) + " " +
                                str(res_l[i]) + "\n")

        # Check whether the length of result is right; Make sure multiprocess work well
        print(len(res_l))

    elif FLAGS.do_sent_gen:
        txt_gen_list = []
        with open(FLAGS.input_txt_dir, "r") as read_txt:
            for input_txt in read_txt:
                if len(input_txt.strip()) != 0:
                    txt_gen_list.append(
                        sent_gen(tmp_Vocab, input_txt.strip(), n_token,
                                 cutoffs, "/gpu:1"))

        with open("sent_generation.txt", "w", encoding="utf-8") as write_res:
            for line in txt_gen_list:
                write_res.write(line + "\n")
def _write_sent_ppl_outputs(input_csv, txt_input, ppl_values):
    """Append perplexities to the CSV rows and write both output files.

    ``input_csv[0]`` must already end with the "ppl" header. Data row ``i``
    (``i >= 1``) receives ``ppl_values[i - 1]``; any extra trailing values
    (results for batch-padding rows) are ignored. Writes the CSV to
    ``FLAGS.output_file_dir`` and a "sentence<TAB>ppl" listing to
    ``sent_ppl_pred.txt``.
    """
    for i in range(1, len(input_csv)):
        input_csv[i].append(ppl_values[i - 1])
    output_df = pd.DataFrame(input_csv[1:], columns=input_csv[0])
    output_df.to_csv(FLAGS.output_file_dir,
                     sep=",",
                     index=False,
                     encoding="utf-8-sig")

    with open("sent_ppl_pred.txt", "w", encoding="utf-8") as write_res:
        for i, sentence in enumerate(txt_input):
            write_res.write(sentence + "\t" + str(ppl_values[i]) + "\n")


def main(unused_argv):
    """Batched sentence-perplexity prediction driven by FLAGS.

    With ``FLAGS.do_sent_ppl_pred``, the steps are:
      * read the CSV at ``FLAGS.input_file_dir`` (row 0 is the header, and
        column 0 of every other non-blank row is a sentence);
      * truncate or pad each sentence to ``FLAGS.limit_len`` tokens;
      * arrange the sentences into ``FLAGS.pred_batch_size`` batch rows;
      * score them with ``sent_ppl``, optionally across
        ``FLAGS.multiprocess`` worker processes;
      * write the results as a "ppl" column.
    """
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get corpus info. Only the interior cutoffs (without the first and last
    # entries) are passed on.
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    tf.logging.info("n_token {}".format(n_token))

    # Rebuild the vocabulary from the dataset's train.txt to encode inputs.
    tmp_Vocab = Vocab(special=["<bos>", "<eos>", "<UNK>"])
    tmp_Vocab.count_file("../data/{}/train.txt".format(FLAGS.dataset),
                         add_eos=False)
    tmp_Vocab.build_vocab()

    if FLAGS.do_sent_ppl_pred:
        encoded_txt_input = []
        txt_input = []
        input_csv = []
        # newline="" is required by the csv module so that quoted fields
        # containing line breaks are parsed correctly.
        with open(FLAGS.input_file_dir, "r", newline="") as read_file:
            for line in csv.reader(read_file):
                # Skip blank rows and rows whose first cell is empty. (A
                # `line[0].strip() != 0` test compares str with int and is
                # always true, so it never filtered anything.)
                if line and line[0].strip():
                    input_csv.append(line)

        for row in input_csv[1:]:
            sentence = row[0].strip()
            txt_input.append(sentence)
            encoded_txt_input.append(list(tmp_Vocab.encode_sents(
                sentence, add_eos=True, ordered=True)))

        # Real sentence lengths, capped at limit_len.
        encoded_txt_input_len = [min(len(encoded_txt), FLAGS.limit_len)
                                 for encoded_txt in encoded_txt_input]

        # cut_pad presumably truncates / pads with 0 to exactly limit_len
        # tokens (the reshape below requires every row to have that length).
        encoded_txt_input = [
            cut_pad(line, FLAGS.limit_len, 0) for line in encoded_txt_input
        ]

        # Pad the sentence count up to a multiple of pred_batch_size with
        # all-zero dummy rows; their results are dropped when writing output.
        remainder = len(encoded_txt_input) % FLAGS.pred_batch_size
        if remainder != 0:
            pad_len = FLAGS.pred_batch_size - remainder
            encoded_txt_input = encoded_txt_input + \
                [[0] * FLAGS.limit_len] * pad_len
            encoded_txt_input_len = encoded_txt_input_len + \
                [FLAGS.limit_len] * pad_len

        # Shape (pred_batch_size, num_batches, limit_len): batch row j holds a
        # contiguous run of consecutive sentences.
        encoded_txt_input = np.array(encoded_txt_input).reshape(
            FLAGS.pred_batch_size, -1, FLAGS.limit_len)
        encoded_txt_input_len = np.array(encoded_txt_input_len).reshape(
            FLAGS.pred_batch_size, -1)
        input_csv[0].append("ppl")

        num_batches = encoded_txt_input.shape[1]
        if FLAGS.multiprocess == 1 or num_batches // FLAGS.multiprocess == 0:
            ppl_list = sent_ppl((encoded_txt_input, encoded_txt_input_len),
                                n_token, cutoffs, "/gpu:1")

            _write_sent_ppl_outputs(input_csv, txt_input, ppl_list)

            # Check whether the length of result is right; Make sure batch-predict work well
            print(len(ppl_list))
        else:
            pool = multiprocessing.Pool(FLAGS.multiprocess)
            parti_batch_num = num_batches // FLAGS.multiprocess
            pro_res_l = []

            for i in range(FLAGS.multiprocess):
                print("Setting process-%s" % i)
                # TODO: choose the GPU per process (move on to the next
                # gpu:xx when gpu:1 is full); every worker uses "/gpu:1".

                # Each worker gets a slice of batch columns; the last one
                # also takes the remainder.
                start = i * parti_batch_num
                if i + 1 == FLAGS.multiprocess:
                    end = num_batches
                else:
                    end = (i + 1) * parti_batch_num
                pro_res_l.append(pool.apply_async(
                    sent_ppl,
                    args=((encoded_txt_input[:, start:end, :],
                           encoded_txt_input_len[:, start:end]),
                          n_token, cutoffs, "/gpu:1")))

            # Assumes each worker returns its results batch-row-major, i.e.
            # pred_batch_size equal chunks, one per batch row. Regroup them by
            # batch row so the concatenation restores the original order.
            res_l = [[] for _ in range(FLAGS.pred_batch_size)]

            for pro_res in pro_res_l:
                proc_i_res = pro_res.get()
                parti_len = len(proc_i_res) // FLAGS.pred_batch_size
                for j in range(FLAGS.pred_batch_size):
                    res_l[j].extend(proc_i_res[j * parti_len:(j + 1) *
                                               parti_len])

            pool.close()
            pool.join()
            print('All subprocesses done.')

            res_merge = [ppl for batch_row in res_l for ppl in batch_row]
            tf.logging.info('#time: {}'.format(time.time()))

            _write_sent_ppl_outputs(input_csv, txt_input, res_merge)

            # Check whether the length of result is right; Make sure multiprocess work well
            print(len(res_merge))
class Corpus(object):
    """Build a vocabulary for a dataset, encode its splits and export TFRecords.

    Which files are counted, how each split is encoded, and which cutoffs are
    used all depend on the ``dataset`` name passed to the constructor.
    """

    # Datasets whose splits are written as a plain ordered stream.
    _ORDERED_DATASETS = ("ptb", "enwik8", "text8", "sb2", "sb92")
    # Datasets whose splits are encoded together with document boundaries.
    _BOUNDARY_DATASETS = ("wt2", "wt103", "wt103small")
    # Regex passed to ``encode_file`` to locate document headers such as
    # "= Title =". Raw string: "\=" is not a valid escape in a normal literal.
    _DOC_BOUNDARY_PATTERN = r"\=[^=]+\="

    def __init__(self, path, dataset, *args, **kwargs):
        """Count symbols, build the vocabulary and encode the data splits.

        Args:
            path: Directory holding ``train.txt`` / ``valid.txt`` /
                ``test.txt`` (for lm1b, also the 1-billion-word benchmark
                directory with the training shards).
            dataset: Dataset name selecting the counting/encoding strategy.
            *args, **kwargs: Forwarded unchanged to ``Vocab``.
        """
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        # Step 1: count symbols.
        if self.dataset in ("ptb", "wt2", "enwik8", "text8", "sb2", "sb92"):
            for split in ("train", "valid", "test"):
                self.vocab.count_file(os.path.join(path, split + ".txt"))
        elif self.dataset in ("wt103", "wt103small"):
            self.vocab.count_file(os.path.join(path, "train.txt"))
        elif self.dataset == "lm1b":
            train_path_pattern = os.path.join(
                path, "1-billion-word-language-modeling-benchmark-r13output",
                "training-monolingual.tokenized.shuffled", "news.en-*")
            # glob() order is arbitrary; sort so shard ids (and the seeded
            # shuffles in convert_to_tfrecords) are reproducible.
            train_paths = sorted(glob(train_path_pattern))
            # The vocab will load from file when build_vocab() is called,
            # so the training shards are not counted here.

        self.vocab.build_vocab()

        # Step 2: encode the splits.
        if self.dataset in ("ptb", "sb2", "sb92"):
            self.train = self.vocab.encode_file(
                os.path.join(path, "train.txt"), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, "valid.txt"), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, "test.txt"), ordered=True)
        elif self.dataset in self._BOUNDARY_DATASETS:
            self.train, self.train_boundary = self._encode_with_boundary(
                path, "train")
            self.valid, self.valid_boundary = self._encode_with_boundary(
                path, "valid")
            self.test, self.test_boundary = self._encode_with_boundary(
                path, "test")
        elif self.dataset in ("enwik8", "text8"):
            self.train = self.vocab.encode_file(
                os.path.join(path, "train.txt"), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, "test.txt"), ordered=True, add_eos=False)
        elif self.dataset == "lm1b":
            # Training data stays as a list of shard paths; it is encoded
            # shard by shard in convert_to_tfrecords.
            self.train = train_paths
            valid_path = os.path.join(path, "valid.txt")
            # The validation file is reused as the test split.
            test_path = valid_path
            self.valid = self.vocab.encode_file(
                valid_path, ordered=True, add_double_eos=True)
            self.test = self.vocab.encode_file(
                test_path, ordered=True, add_double_eos=True)

        # Step 3: per-dataset vocabulary split points, closed by the vocab
        # size (presumably adaptive-softmax cluster bounds -- confirm against
        # the model code). Other datasets get no cutoffs.
        head_cutoffs = {
            "sb92": [0, 10000, 20000],
            "wt103small": [0, 20000, 40000],
            "wt103": [0, 20000, 40000, 200000],
            "lm1b": [0, 60000, 100000, 640000],
        }.get(self.dataset)
        if head_cutoffs is None:
            self.cutoffs = []
        else:
            self.cutoffs = head_cutoffs + [len(self.vocab)]

    def _encode_with_boundary(self, path, split):
        """Encode ``<path>/<split>.txt`` and also return its document boundaries."""
        return self.vocab.encode_file(
            os.path.join(path, split + ".txt"),
            ordered=True,
            ret_doc_boundary=True,
            pattern=self._DOC_BOUNDARY_PATTERN)

    def _lm1b_train_records(self, save_dir, bsz, tgt_len, FLAGS):
        """Write shuffled TFRecords for every lm1b training shard.

        Returns:
            ``(file_names, num_batch)`` accumulated over all shards/shuffles.
        """
        # Fix the global NumPy RNG seed before any shard is shuffled.
        np.random.seed(123456)
        file_names = []
        num_batch = 0

        if FLAGS.num_procs > 1:
            preprocess = partial(
                _preprocess,
                train=self.train,
                vocab=self.vocab,
                save_dir=save_dir,
                bsz=bsz,
                tgt_len=tgt_len,
                num_shuffle=FLAGS.num_shuffle)
            # Context manager guarantees the worker processes are cleaned up
            # (the original pool was never closed or joined).
            with mp.Pool(processes=FLAGS.num_procs) as pool:
                results = pool.map(preprocess, range(len(self.train)))
            for res in results:
                file_names.extend(res[0])
                num_batch += res[1]
        else:
            num_shuffle = FLAGS.num_shuffle
            for shard, path in enumerate(self.train):
                data_shard = self.vocab.encode_file(
                    path, ordered=False, add_double_eos=True)
                for shuffle in range(num_shuffle):
                    print("Processing shard {} shuffle {}".format(
                        shard, shuffle))
                    basename = "train-{:03d}-{:02d}".format(shard, shuffle)
                    np.random.shuffle(data_shard)
                    file_name, shard_batches = create_ordered_tfrecords(
                        save_dir, basename, np.concatenate(data_shard),
                        bsz, tgt_len)
                    file_names.append(file_name)
                    num_batch += shard_batches

        return file_names, num_batch

    def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len, **kwargs):
        """Write ``split`` as TFRecords plus a JSON record-info file.

        The record-info file is
        ``<save_dir>/record_info-<split>.bsz-<bsz>.tlen-<tgt_len>.json`` and
        holds the record file names, the batch count and whether document
        boundaries were used.

        Args:
            split: Split attribute to export ("train", "valid" or "test").
            save_dir: Output directory.
            bsz: Batch size forwarded to ``create_ordered_tfrecords``.
            tgt_len: Target length forwarded to ``create_ordered_tfrecords``.
            **kwargs: Must provide ``FLAGS`` with ``num_passes`` (ordered and
                boundary datasets) or ``num_procs``/``num_shuffle`` (lm1b
                training split).

        Raises:
            ValueError: If ``self.dataset`` is not supported. Raised before
                any file is written (previously an empty record-info file was
                created and then ``UnboundLocalError`` was raised).
        """
        FLAGS = kwargs.get('FLAGS')

        if self.dataset in self._ORDERED_DATASETS:
            file_name, num_batch = create_ordered_tfrecords(
                save_dir,
                split,
                getattr(self, split),
                bsz,
                tgt_len,
                num_passes=FLAGS.num_passes)
            file_names = [file_name]
        elif self.dataset in self._BOUNDARY_DATASETS:
            file_name, num_batch = create_ordered_tfrecords(
                save_dir,
                split,
                getattr(self, split),
                bsz,
                tgt_len,
                num_passes=FLAGS.num_passes,
                boundary=getattr(self, split + "_boundary"))
            file_names = [file_name]
        elif self.dataset == "lm1b":
            if split == "train":
                file_names, num_batch = self._lm1b_train_records(
                    save_dir, bsz, tgt_len, FLAGS)
            else:
                file_name, num_batch = create_ordered_tfrecords(
                    save_dir, split, getattr(self, split), bsz, tgt_len)
                file_names = [file_name]
        else:
            raise ValueError("Unsupported dataset: {}".format(self.dataset))

        record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
            split, bsz, tgt_len)
        record_info_path = os.path.join(save_dir, record_name)
        record_info = {
            "filenames": file_names,
            "num_batch": num_batch,
            "boundary": self.dataset in self._BOUNDARY_DATASETS,
        }
        with open(record_info_path, "w") as fp:
            json.dump(record_info, fp)
from vocabulary import Vocab
import csv

# Build a vocabulary from the training text and export it as a two-column
# TSV file ("label", "index"): one row per symbol with its vocabulary index.
tmp_Vocab = Vocab()
tmp_Vocab.count_file("../data/test/train.txt", add_eos=False)
tmp_Vocab.build_vocab()

# newline='' lets the csv module control line endings itself (otherwise
# blank rows appear on Windows); an explicit utf-8 encoding keeps non-ASCII
# symbols writable regardless of the platform's default locale encoding.
with open('../data/test/label.tsv', 'wt', encoding='utf-8', newline='') as f:
    tsv_writer = csv.writer(f, delimiter='\t')
    tsv_writer.writerow(['label', 'index'])
    for idx in range(len(tmp_Vocab.idx2sym)):
        tsv_writer.writerow([tmp_Vocab.idx2sym[idx], idx])