Example #1
    def __init__(self,
                 config,
                 thread_num=None,
                 input_queue=None,
                 output_queue=None,
                 job_name='GPTPreprocessor'):
        super(MyPreprocessor, self).__init__(job_name, thread_num, input_queue, output_queue)
        self.first_sequence = config.first_sequence
        self.sequence_length = config.sequence_length
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=config.vocab_file_path, do_lower_case=True)
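The constructor reads just three attributes from config. Purely for illustration, a minimal stand-in config (the field values and the use of SimpleNamespace are assumptions, not taken from this snippet) could look like:

from types import SimpleNamespace

# Hypothetical config object carrying the three fields read in __init__ above.
config = SimpleNamespace(
    first_sequence="text",        # name of the input text field
    sequence_length=128,          # maximum number of tokens per example
    vocab_file_path="vocab.txt",  # path to the WordPiece vocabulary file
)
# preprocessor = MyPreprocessor(config, thread_num=1)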
Example #2
    def test_full_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
            "runn", "##ing", ","
        ]
        with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
            if six.PY2:
                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
            else:
                vocab_writer.write("".join([x + "\n" for x in vocab_tokens
                                            ]).encode("utf-8"))

            vocab_file = vocab_writer.name

        tokenizer = tokenization.FullTokenizer(vocab_file)
        os.unlink(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertAllEqual(tokens,
                            ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertAllEqual(tokenizer.convert_tokens_to_ids(tokens),
                            [7, 4, 5, 10, 8, 9])
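The same tokenize / convert_tokens_to_ids round trip can be exercised outside a test harness. Below is a minimal, self-contained sketch assuming Python 3 and the same tokenization module; the vocabulary and the expected outputs are taken from the test above.

import os
import tempfile

import tokenization  # the BERT-style tokenization module used throughout these examples

# Tiny WordPiece vocabulary, one token per line, as in the test above.
vocab_tokens = [
    "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
    "runn", "##ing", ","
]

with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) as vocab_writer:
    vocab_writer.write("".join(token + "\n" for token in vocab_tokens))
    vocab_file = vocab_writer.name

tokenizer = tokenization.FullTokenizer(vocab_file)
os.unlink(vocab_file)

tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
print(tokens)                                   # ['un', '##want', '##ed', ',', 'runn', '##ing']
print(tokenizer.convert_tokens_to_ids(tokens))  # [7, 4, 5, 10, 8, 9]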
Example #3
                    type=str,
                    help='Base filename to use. THIS MUST BE A LOCAL FILE.')
parser.add_argument(
    '-max_seq_length',
    dest='max_seq_length',
    default=1024,
    type=int,
    help='Max sequence length',
)

args = parser.parse_args()
random.seed(args.seed + args.fold)

print("now begin...")
tokenizer = tokenization.FullTokenizer(
    vocab_file="D:/EssayKiller_V1/AutoWritter/dataset/clue-vocab.txt",
    do_lower_case=True)


class S3TFRecordWriter(object):
    def __init__(self, fn):
        self.fn = fn
        self.s3client = None
        self.gclient = None
        self.bucket_name = None
        self.file_name = None
        self.storage_dir = None
        self.writer = tf.python_io.TFRecordWriter(fn)

    def write(self, x):
        self.writer.write(x)
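S3TFRecordWriter.write expects already-serialized bytes. Below is a minimal sketch of how tokenized ids are commonly packed into a tf.train.Example before being written; the feature name "input_ids" and the file name are assumptions, not taken from this snippet.

import tensorflow as tf

def ids_to_serialized_example(input_ids):
    # Wrap the integer token ids in a tf.train.Example and serialize it to bytes.
    feature = {
        "input_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids))
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

# Hypothetical usage with the writer defined above:
# writer = S3TFRecordWriter("train_0000.tfrecord")
# ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("some input text"))
# writer.write(ids_to_serialized_example(ids))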
Example #4
                    default='realnews.jsonl',
                    type=str,
                    help='Base filename to use. THIS MUST BE A LOCAL FILE.')
parser.add_argument(
    '-max_seq_length',
    dest='max_seq_length',
    default=1025,
    type=int,
    help='Max sequence length',
)

args = parser.parse_args()
random.seed(args.seed + args.fold)

#encoder = get_encoder()
tokenizer = tokenization.FullTokenizer(
    vocab_file="bert-base-chinese-vocab.txt", do_lower_case=True)


class TFRecordWriter(object):
    def __init__(self, fn):
        self.fn = fn
        if fn.startswith('gs://'):
            from google.cloud import storage
            self.s3client = None
            self.gclient = storage.Client()
            self.storage_dir = TemporaryDirectory()
            self.writer = tf.python_io.TFRecordWriter(
                os.path.join(self.storage_dir.name, 'temp.tfrecord'))
            self.bucket_name, self.file_name = self.fn.split(
                'gs://', 1)[1].split('/', 1)
Example #5
    help='Base filename to use. THIS MUST BE A LOCAL FILE.'
)
parser.add_argument(
    '-max_seq_length',
    dest='max_seq_length',
    default=1025,
    type=int,
    help='Max sequence length',
)


args = parser.parse_args()
random.seed(args.seed + args.fold)

#encoder = get_encoder()
tokenizer = tokenization.FullTokenizer(
    vocab_file="/gdrive/My Drive/gpt2-ml-Finetune-1.0/tokenization/bert-base-chinese-vocab.txt", do_lower_case=True)


class TFRecordWriter(object):
    def __init__(self, fn):
        self.fn = fn
        if fn.startswith('gs://'):
            from google.cloud import storage
            self.s3client = None
            self.gclient = storage.Client()
            self.storage_dir = TemporaryDirectory()
            self.writer = tf.python_io.TFRecordWriter(
                os.path.join(self.storage_dir.name, 'temp.tfrecord'))
            self.bucket_name, self.file_name = self.fn.split(
                'gs://', 1)[1].split('/', 1)
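The snippet stops before the upload step. The sketch below shows how the locally buffered TFRecord would typically be flushed to GCS when the writer is closed; the close method itself and the exact upload flow are assumptions, based on the standard google.cloud.storage client API.

    def close(self):
        # Finish the local TFRecord, then (for gs:// paths) upload it with the storage client.
        self.writer.close()
        if self.gclient is not None:
            bucket = self.gclient.get_bucket(self.bucket_name)
            blob = bucket.blob(self.file_name)
            blob.upload_from_filename(
                os.path.join(self.storage_dir.name, 'temp.tfrecord'))
            self.storage_dir.cleanup()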
Example #6
    return {
        'extraction': tokenization.printable_text(
            ''.join(tokenizer.convert_ids_to_tokens(output_tokens))),
        'start_ind': start_ind,
        'end_ind': end_ind,
    }


args = parser.parse_args()
proj_root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
vocab_file_path = os.path.join(proj_root_path, "tokenization/clue-vocab.txt")
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file_path,
                                       do_lower_case=True)
news_config = GroverConfig.from_json_file(args.config_fn)

# We might have to split the batch into multiple chunks if the batch size is too large
default_mbs = {12: 32, 24: 16, 48: 3}
max_batch_size = args.max_batch_size if args.max_batch_size is not None else default_mbs[
    news_config.num_hidden_layers]

# factorize args.batch_size = (num_chunks * batch_size_per_chunk) s.t. batch_size_per_chunk < max_batch_size
num_chunks = int(np.ceil(args.batch_size / max_batch_size))
batch_size_per_chunk = int(np.ceil(args.batch_size / num_chunks))

# This controls the top p for each generation.
top_p = np.ones(
    (num_chunks, batch_size_per_chunk), dtype=np.float32) * args.top_p
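A worked instance of the factorization above, using a hypothetical requested batch size of 20 together with the 48-layer default of max_batch_size = 3:

import numpy as np

batch_size, max_batch_size = 20, 3                            # batch_size is a made-up value
num_chunks = int(np.ceil(batch_size / max_batch_size))        # ceil(20 / 3) = 7
batch_size_per_chunk = int(np.ceil(batch_size / num_chunks))  # ceil(20 / 7) = 3
assert num_chunks * batch_size_per_chunk >= batch_size        # 21 >= 20, nothing is dropped
assert batch_size_per_chunk <= max_batch_size                 # each chunk fits the size limit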
Example #7
                    help='Base filename to use. THIS MUST BE A LOCAL FILE.')
parser.add_argument(
    '-max_seq_length',
    dest='max_seq_length',
    default=1024,
    type=int,
    help='Max sequence length',
)

args = parser.parse_args()

args.input_fn = '166893.json'

random.seed(args.seed + args.fold)

tokenizer = tokenization.FullTokenizer(vocab_file="clue-vocab.txt",
                                       do_lower_case=True)


class S3TFRecordWriter(object):
    def __init__(self, fn):
        self.fn = fn
        if fn.startswith('s3://'):
            from boto3.s3.transfer import TransferConfig
            import boto3
            self.gclient = None
            self.s3client = boto3.client('s3', )
            self.storage_dir = TemporaryDirectory()
            self.writer = tf.python_io.TFRecordWriter(
                os.path.join(self.storage_dir.name, 'temp.tfrecord'))
            self.bucket_name, self.file_name = self.fn.split(
                's3://', 1)[1].split('/', 1)
Example #8
    end_ind = output_tokens.shape[0]

    return {
        # join the tokens first: printable_text expects a string, not a list
        'extraction': tokenization.printable_text(
            ''.join(tokenizer.convert_ids_to_tokens(output_tokens))),
        'start_ind': start_ind,
        'end_ind': end_ind,
    }


args = parser.parse_args()
proj_root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
tokenizer = tokenization.FullTokenizer(vocab_file='./kotok',
                                       do_lower_case=True)
news_config = GroverConfig.from_json_file(args.config_fn)

# We might have to split the batch into multiple chunks if the batch size is too large
default_mbs = {12: 32, 24: 16, 48: 3}
max_batch_size = args.max_batch_size if args.max_batch_size is not None else default_mbs[
    news_config.num_hidden_layers]

# factorize args.batch_size = (num_chunks * batch_size_per_chunk) s.t. batch_size_per_chunk < max_batch_size
num_chunks = int(np.ceil(args.batch_size / max_batch_size))
batch_size_per_chunk = int(np.ceil(args.batch_size / num_chunks))

# This controls the top p for each generation.
top_p = np.ones(
    (num_chunks, batch_size_per_chunk), dtype=np.float32) * args.top_p
Example #9
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "tag": SingleTagPredictionProcessor,
        "multitag": MultiTagPredictionProcessor,
        "phrase": ExtractPhrasesProcessor,
        "seg-phrase": ExtractPhrasesFromSegmentedInputProcessor,
        "all-phrase": ExtractAllPhrasesProcessor,
        "phrase-and-tag": ExtractAllPhrasesAndTagsProcessor,
        "alimama": AlimamaTitleClassificationPorcessor,
        'intent': IntentClassificationProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    processor = processors[task_name](FLAGS.data_dir, tokenizer,
                                      FLAGS.max_seq_length)

    label_list = processor.label_list

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples()
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(processor,
                                bert_config=bert_config,
                                num_labels=len(label_list),
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_tpu=FLAGS.use_tpu,
                                use_one_hot_embeddings=FLAGS.use_tpu)

    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))
    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = TPUEstimator(use_tpu=FLAGS.use_tpu,
                             model_fn=model_fn,
                             config=run_config,
                             train_batch_size=FLAGS.train_batch_size,
                             eval_batch_size=FLAGS.eval_batch_size,
                             predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        if not tf.gfile.Exists(train_file) or not FLAGS.data_converted:
            processor.file_based_convert_examples_to_features(
                train_examples, train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = processor.file_based_input_fn_builder(
            input_file=train_file, is_training=True, drop_remainder=True)
        train_hook = tf.train.LoggingTensorHook(['loss/train_loss'],
                                                every_n_iter=100)
        train_spec = tf.estimator.TrainSpec(train_input_fn, num_train_steps,
                                            [train_hook])

    if FLAGS.do_eval:

        eval_examples = processor.get_dev_examples()
        num_actual_eval_examples = len(eval_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples) % FLAGS.eval_batch_size != 0:
                eval_examples.append(PaddingInputExample())

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        if not tf.gfile.Exists(eval_file) or not FLAGS.data_converted:
            processor.file_based_convert_examples_to_features(
                eval_examples, eval_file)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(eval_examples), num_actual_eval_examples,
                        len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            assert len(eval_examples) % FLAGS.eval_batch_size == 0
            eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = processor.file_based_input_fn_builder(
            input_file=eval_file,
            is_training=False,
            drop_remainder=eval_drop_remainder)
        eval_spec = tf.estimator.EvalSpec(eval_input_fn,
                                          steps=None,
                                          throttle_secs=FLAGS.throttle_secs)

    if FLAGS.do_train and FLAGS.do_eval:
        tf.logging.info("***** Running training and evaluation*****")
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    elif FLAGS.do_train:
        estimator.train(input_fn=train_input_fn,
                        max_steps=num_train_steps,
                        hooks=[train_hook])

    elif FLAGS.do_eval:

        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict:
        if FLAGS.predict_from_file:
            predict_examples = processor.get_prediction_examples(
                FLAGS.predict_from_file)
        else:
            predict_examples = processor.get_test_examples()
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(PaddingInputExample())

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        if not tf.gfile.Exists(predict_file) or not FLAGS.data_converted:
            processor.file_based_convert_examples_to_features(
                predict_examples, predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = True if FLAGS.use_tpu else False
        predict_input_fn = processor.file_based_input_fn_builder(
            input_file=predict_file,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        results = estimator.predict(input_fn=predict_input_fn)

        metrics = [(result['probabilities'], result['input_ids'],
                    result['label_ids'], result['input_mask'])
                   for result in results]
        probabilities, input_ids, label_ids, input_mask = list(zip(*metrics))

        processor.post_process(FLAGS.output_dir, label_ids, probabilities,
                               input_mask, input_ids, 0.5)
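main() reads its whole configuration from FLAGS, whose definitions are not part of this snippet. Below is a minimal sketch of how a few of them are typically declared in the same script; the flag names match the usages above, while the defaults shown here are assumptions.

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("task_name", "tag", "Which processor to use, e.g. tag, multitag, phrase.")
flags.DEFINE_string("data_dir", None, "Directory containing the input data.")
flags.DEFINE_string("vocab_file", None, "Path to the WordPiece vocabulary file.")
flags.DEFINE_string("bert_config_file", None, "Path to the BERT config JSON.")
flags.DEFINE_string("output_dir", None, "Where checkpoints and TFRecord files are written.")
flags.DEFINE_bool("do_lower_case", True, "Lowercase the input before tokenization.")
flags.DEFINE_bool("do_train", False, "Run training.")
flags.DEFINE_bool("do_eval", False, "Run evaluation.")
flags.DEFINE_bool("do_predict", False, "Run prediction.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum WordPiece sequence length.")

if __name__ == "__main__":
    tf.app.run()  # dispatches to main(_) defined above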
Example #10
    type=str,
    help='Base filename to use. THIS MUST BE A LOCAL FILE.')
parser.add_argument(
    '-max_seq_length',
    dest='max_seq_length',
    default=1025,
    type=int,
    help='Max sequence length',
)

args = parser.parse_args()
random.seed(args.seed + args.fold)

#encoder = get_encoder()
tokenizer = tokenization.FullTokenizer(
    vocab_file="/Users/zchai/PycharmProjects/gpt2-ml/tokenization/bert-base-chinese-vocab.txt",
    do_lower_case=True)


class S3TFRecordWriter(object):
    def __init__(self, fn):
        self.fn = fn

        self.s3client = None
        self.gclient = None
        self.bucket_name = None
        self.file_name = None
        self.storage_dir = None
        self.writer = tf.io.TFRecordWriter(fn)

    def write(self, x):
        self.writer.write(x)
Example #11
args.min_len = 100

proj_root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
vocab_file_path = os.path.join(proj_root_path, "tokenization/clue-vocab.txt")
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file_path, do_lower_case=True)  # two tokenization passes happen inside: for Chinese, one splits on punctuation and the other is WordPiece; using both is the correct approach.
news_config = GroverConfig.from_json_file(args.config_fn)

# We might have to split the batch into multiple chunks if the batch size is too large
default_mbs = {12: 32, 24: 16, 48: 3}
max_batch_size = args.max_batch_size if args.max_batch_size is not None else default_mbs[news_config.num_hidden_layers]

# factorize args.batch_size = (num_chunks * batch_size_per_chunk) s.t. batch_size_per_chunk < max_batch_size
num_chunks = int(np.ceil(args.batch_size / max_batch_size))
batch_size_per_chunk = int(np.ceil(args.batch_size / num_chunks))

# This controls the top p for each generation.
top_p = np.ones((num_chunks, batch_size_per_chunk), dtype=np.float32) * args.top_p

tf_config = tf.ConfigProto(allow_soft_placement=True)
Example #12
def predict():

    ##### temporarily silence TF deprecation warnings
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # macOS-specific setting; comment this out on other systems
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
    from tensorflow.python.util import deprecation
    deprecation._PRINT_DEPRECATION_WARNINGS = False
    try:
        from tensorflow.python.util import module_wrapper as deprecation
    except ImportError:
        from tensorflow.python.util import deprecation_wrapper as deprecation
    deprecation._PER_MODULE_WARNING_LIMIT = 0
    #####

    parser = argparse.ArgumentParser(description='Contextual generation (aka given some metadata we will generate articles)')
    parser.add_argument(
        '-metadata_fn',
        dest='metadata_fn',
        type=str,
        help='Path to a JSONL containing metadata',
    )
    parser.add_argument(
        '-out_fn',
        dest='out_fn',
        type=str,
        help='Out jsonl, which will contain the completed jsons',
    )
    parser.add_argument(
        '-input',
        dest='input',
        type=str,
        help='Text to complete',
    )
    parser.add_argument(
        '-model_config_fn',
        dest='model_config_fn',
        default='configs/mega.json',
        type=str,
        help='Configuration JSON for the model',
    )
    parser.add_argument(
        '-model_ckpt',
        dest='model_ckpt',
        default='model.ckpt-220000',
        type=str,
        help='checkpoint file for the model',
    )
    parser.add_argument(
        '-target',
        dest='target',
        default='article',
        type=str,
        help='What to generate for each item in metadata_fn. can be article (body), title, etc.',
    )
    parser.add_argument(
        '-batch_size',
        dest='batch_size',
        default=1,
        type=int,
        help='How many things to generate per context. will split into chunks if need be',
    )
    parser.add_argument(
        '-num_folds',
        dest='num_folds',
        default=1,
        type=int,
        help='Number of folds. useful if we want to split up a big file into multiple jobs.',
    )
    parser.add_argument(
        '-fold',
        dest='fold',
        default=0,
        type=int,
        help='which fold we are on. useful if we want to split up a big file into multiple jobs.'
    )
    parser.add_argument(
        '-max_batch_size',
        dest='max_batch_size',
        default=None,
        type=int,
        help='max batch size. You can leave this out and we will infer one based on the number of hidden layers',
    )
    parser.add_argument(
        '-top_p',
        dest='top_p',
        default=0.95,
        type=float,
        help='p to use for top p sampling. if this isn\'t none, use this for everything'
    )
    parser.add_argument(
        '-min_len',
        dest='min_len',
        default=1024,
        type=int,
        help='min length of sample',
    )
    parser.add_argument(
        '-eos_token',
        dest='eos_token',
        default=60000,
        type=int,
        help='eos token id',
    )
    parser.add_argument(
        '-samples',
        dest='samples',
        default=5,
        type=int,
        help='num_samples',
    )

    def extract_generated_target(output_tokens, tokenizer):
        """
        Given some tokens that were generated, extract the target
        :param output_tokens: [num_tokens] thing that was generated
        :param encoder: how they were encoded
        :param target: the piece of metadata we wanted to generate!
        :return:
        """
        # Filter out first instance of start token
        assert output_tokens.ndim == 1

        start_ind = 0
        end_ind = output_tokens.shape[0]

        return {
            'extraction': tokenization.printable_text(''.join(tokenizer.convert_ids_to_tokens(output_tokens))),
            'start_ind': start_ind,
            'end_ind': end_ind,
        }

    # args = parser.parse_args()
    args, unknown = parser.parse_known_args()
    proj_root_path = os.path.dirname(os.path.realpath(__file__))
    vocab_file_path = os.path.join(proj_root_path, "tokenization/clue-vocab.txt")

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file_path , do_lower_case=True)
    news_config = GroverConfig.from_json_file(args.model_config_fn)

    # We might have to split the batch into multiple chunks if the batch size is too large
    default_mbs = {12: 32, 24: 16, 48: 3}
    max_batch_size = args.max_batch_size if args.max_batch_size is not None else default_mbs[news_config.num_hidden_layers]

    # factorize args.batch_size = (num_chunks * batch_size_per_chunk) s.t. batch_size_per_chunk < max_batch_size
    num_chunks = int(np.ceil(args.batch_size / max_batch_size))
    batch_size_per_chunk = int(np.ceil(args.batch_size / num_chunks))

    # This controls the top p for each generation.
    top_p = np.ones((num_chunks, batch_size_per_chunk), dtype=np.float32) * args.top_p

    tf_config = tf.ConfigProto(allow_soft_placement=True)

    with tf.Session(config=tf_config, graph=tf.Graph()) as sess:
        initial_context = tf.placeholder(tf.int32, [batch_size_per_chunk, None])
        p_for_topp = tf.placeholder(tf.float32, [batch_size_per_chunk])
        eos_token = tf.placeholder(tf.int32, [])
        min_len = tf.placeholder(tf.int32, [])
        tokens, probs = sample(news_config=news_config, initial_context=initial_context,
                            eos_token=eos_token, min_len=min_len, ignore_ids=None, p_for_topp=p_for_topp,
                            do_topk=False)

        saver = tf.train.Saver()
        saver.restore(sess, args.model_ckpt)

        '''
        If this is deployed as a web app, none of the print calls are needed.
        The input becomes the message returned from the web form.
        No while loop is needed.
        Return the final "\n".join(l) through a variable and render it in the web page.
        The main parameters (number of samples, length) should be entered by the user
        on the web page, or hard-coded here -- defaults are provided.

        Open issues:
        samples is 5, so the loop below predicts 5 times; how should these 5 results
        be shown on the page?
        min_len has no effect, e.g. at 1024 it still produces articles of only one or
        two hundred characters.
        '''

        # print('🍺Model loaded. \nInput something please:⬇️')

        if request.method == 'POST':
            text = request.form['message']
            # data = [text]  # carried over from the original spam-detection code; unclear whether it is needed here

        for i in range(args.samples):
            # print("Sample,", i + 1, " of ", args.samples)
            line = tokenization.convert_to_unicode(text)
            bert_tokens = tokenizer.tokenize(line)
            encoded = tokenizer.convert_tokens_to_ids(bert_tokens)
            context_formatted = []
            context_formatted.extend(encoded)
            # Format context end

            gens = []
            gens_raw = []
            gen_probs = []
            final_result = []

            for chunk_i in range(num_chunks):
                tokens_out, probs_out = sess.run([tokens, probs],
                                                feed_dict={initial_context: [context_formatted] * batch_size_per_chunk,
                                                            eos_token: args.eos_token, min_len: args.min_len,
                                                            p_for_topp: top_p[chunk_i]})

                for t_i, p_i in zip(tokens_out, probs_out):
                    extraction = extract_generated_target(output_tokens=t_i, tokenizer=tokenizer)
                    gens.append(extraction['extraction'])

            l = re.findall('.{1,70}', gens[0].replace('[UNK]', '').replace('##', ''))
            # the string built on the next line is what should eventually be returned
            # print("\n".join(l))
            # return a for loop
            # https://stackoverflow.com/questions/44564414/how-to-use-a-return-statement-in-a-for-loop
            final_result.append("\n".join(l))
            

    return render_template('result.html', prediction=final_result)
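request and render_template come from Flask, and the snippet does not show how predict() is exposed. A minimal, assumed wiring is sketched below; the route path, host/port, and module layout are guesses, not taken from the source.

from flask import Flask, render_template, request

app = Flask(__name__)
# Register the predict() view defined above for both GET and POST requests.
app.add_url_rule('/predict', view_func=predict, methods=['GET', 'POST'])

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)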
Example #13
                    default='realnews.jsonl',
                    type=str,
                    help='Base filename to use. THIS MUST BE A LOCAL FILE.')
parser.add_argument(
    '-max_seq_length',
    dest='max_seq_length',
    default=1024,
    type=int,
    help='Max sequence length',
)

args = parser.parse_args()
random.seed(args.seed + args.fold)

tokenizer = tokenization.FullTokenizer(
    vocab_file="/data/home/share1/gpt2-ml/dataset/bert-base-chinese-vocab.txt",
    do_lower_case=True)


class S3TFRecordWriter(object):
    def __init__(self, fn):
        self.fn = fn
        if fn.startswith('s3://'):
            from boto3.s3.transfer import TransferConfig
            import boto3
            self.gclient = None
            self.s3client = boto3.client('s3', )
            self.storage_dir = TemporaryDirectory()
            self.writer = tf.python_io.TFRecordWriter(
                os.path.join(self.storage_dir.name, 'temp.tfrecord'))
            self.bucket_name, self.file_name = self.fn.split(
                's3://', 1)[1].split('/', 1)