Example No. 1
def main(unused_argv):
    del unused_argv

    tf.logging.set_verbosity(tf.logging.INFO)

    csv_logger, csv_file = get_csv_logger()
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    tf.logging.info("n_token: {}".format(n_token))

    min_valid_loss = np.inf

    for epoch in range(FLAGS.epochs):
        train_epoch(epoch, csv_logger, n_token, cutoffs)

        tf.logging.info("Evaluating epoch {}".format(epoch))
        valid_log_dict = evaluate(n_token, cutoffs)
        valid_log_dict['epoch'] = epoch
        csv_logger.writerow(valid_log_dict)
        
        if valid_log_dict['valid_loss'] < min_valid_loss:
            min_valid_loss = valid_log_dict['valid_loss']
            tf.logging.info("New min loss {}".format(min_valid_loss))
            save_path = os.path.join(FLAGS.model_dir, "best_model.ckpt")
            tf.logging.info("Saving model in {}".format(save_path))
            saver.save(sess, save_path)  # NOTE: `saver` and `sess` come from the enclosing scope; they are not defined in this excerpt
            
    csv_file.close()
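
This example calls a `get_csv_logger` helper that is not shown. A minimal sketch of what it presumably returns, a `csv.DictWriter` plus the open file handle so the caller can close it, with hypothetical field names matching the keys logged above:

import csv

def get_csv_logger():
    # Hypothetical sketch: field names are taken from the keys logged above
    # ('epoch', 'valid_loss'); the real helper presumably writes under
    # FLAGS.model_dir. Extra keys in the log dict are ignored.
    csv_file = open("metrics.csv", "w", newline="")
    csv_logger = csv.DictWriter(csv_file,
                                fieldnames=["epoch", "valid_loss"],
                                extrasaction="ignore")
    csv_logger.writeheader()
    return csv_logger, csv_file
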
Example No. 2
def main(unused_argv):
  rank, local_rank, size = 0, 0, 1
  if FLAGS.horovod:
    hvd.init()
    rank = hvd.rank()
    local_rank = hvd.local_rank()
    size = hvd.size()
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.fp16:
    os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"
  else:
    os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "0"

  # Get corpus info
  corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
  n_token = corpus_info["vocab_size"]
  cutoffs = corpus_info["cutoffs"][1:-1]
  tf.logging.info("n_token {}".format(n_token))

  setup_dllogger(enabled=True, filename=FLAGS.raport_file, rank=rank)

  if FLAGS.do_train:
    train(n_token, cutoffs, rank, local_rank, size)
  if FLAGS.do_eval:
    evaluate(n_token, cutoffs)
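
The `train` call receives `rank`, `local_rank`, and `size` so each Horovod worker can shard the data and claim one GPU. The usual pinning step, sketched under the assumption that `train` follows the standard Horovod TF1 pattern (its body is not shown):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
config = tf.compat.v1.ConfigProto()
# Make each process see only the GPU that matches its local rank.
config.gpu_options.visible_device_list = str(hvd.local_rank())
# A session built from `config` would then host the sharded training loop.
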
Example No. 3
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    tf.logging.info("n_token {}".format(n_token))

    dynamic_eval(n_token, cutoffs, "/gpu:0")
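
Every example here slices `corpus_info["cutoffs"][1:-1]`. In the Transformer-XL data pipeline the stored cutoff list begins with 0 and ends with the vocabulary size, so the slice keeps only the interior boundaries that the adaptive softmax needs. An illustration with wt103-style values (shown only as an example of the shape):

corpus_info = {"vocab_size": 267735,
               "cutoffs": [0, 20000, 40000, 200000, 267735]}
cutoffs = corpus_info["cutoffs"][1:-1]
print(cutoffs)  # [20000, 40000, 200000]
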
Example No. 4
def main(unused_argv):
    del unused_argv  # Unused

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    tf.compat.v1.logging.info("n_token {}".format(n_token))

    if FLAGS.do_train:
        train(n_token, cutoffs, "/gpu:0")
    if FLAGS.do_eval:
        evaluate(n_token, cutoffs, "/gpu:0")
    if FLAGS.do_inference:
        inference(n_token, cutoffs, "/gpu:0")
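
`FLAGS` in all of these examples is a module-level flags object. A minimal sketch of the subset this example reads, using absl flags (the names come from the code above; the defaults are assumptions):

from absl import flags

flags.DEFINE_string("corpus_info_path", "", "Path to the corpus info JSON.")
flags.DEFINE_bool("do_train", True, "Run training.")
flags.DEFINE_bool("do_eval", False, "Run evaluation.")
flags.DEFINE_bool("do_inference", False, "Run inference.")
FLAGS = flags.FLAGS
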
Example No. 5
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]

    if FLAGS.save_steps == 0:
        FLAGS.save_steps = None

    if not FLAGS.do_eval_only:
        # Get train input function
        train_input_fn, train_record_info = data_utils.get_input_fn(
            record_info_dir=FLAGS.record_info_dir,
            split="train",
            per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts,
            use_tpu=FLAGS.use_tpu)
        train_bin_sizes = train_record_info["bin_sizes"]
        num_train_batch = train_record_info["num_batch"]

        # Get train cache function
        train_cache_fn = get_cache_fn(FLAGS.mem_len)
    else:
        train_bin_sizes = []
        num_train_batch = None
        train_cache_fn = None

    if FLAGS.do_eval or FLAGS.do_eval_only:
        assert FLAGS.num_hosts == 1
        # Get eval input function
        eval_input_fn, eval_record_info = data_utils.get_input_fn(
            record_info_dir=FLAGS.record_info_dir,
            split=FLAGS.eval_split,
            per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts,
            use_tpu=FLAGS.use_tpu)
        eval_bin_sizes = eval_record_info["bin_sizes"]
        num_eval_batch = eval_record_info["num_batch"]

        if FLAGS.max_eval_batch > 0:
            num_eval_batch = min(FLAGS.max_eval_batch, num_eval_batch)

        # Get eval cache function
        eval_cache_fn = get_cache_fn(FLAGS.mem_len)
        model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes,
                                eval_bin_sizes)
    else:
        eval_cache_fn = None
        model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, [])

    ##### Create estimator
    # TPU Configuration
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations,
            num_shards=FLAGS.num_core_per_host * FLAGS.num_hosts,
            per_host_input_for_training=per_host_input),
        keep_checkpoint_max=100000,  # effectively save all checkpoints
        save_checkpoints_secs=None,
        save_checkpoints_steps=FLAGS.save_steps)

    # warm start
    warm_start_from = None
    if FLAGS.warm_start_path is not None:
        warm_start_from = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=FLAGS.warm_start_path)

    # TPU Estimator
    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        train_cache_fn=train_cache_fn,
        eval_cache_fn=eval_cache_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        params={
            "data_dir": FLAGS.data_dir,
            "track_mean": FLAGS.track_mean
        },
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        warm_start_from=warm_start_from)

    if FLAGS.do_eval_only:
        if FLAGS.eval_ckpt_path is not None:
            ret = estimator.evaluate(input_fn=eval_input_fn,
                                     steps=num_eval_batch,
                                     checkpoint_path=FLAGS.eval_ckpt_path)
            tf.logging.info("=" * 200)
            log_str = "Eval results | "
            for key, val in ret.items():
                log_str += "{} {} | ".format(key, val)
            tf.logging.info(log_str)
            tf.logging.info("=" * 200)
        else:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
            eval_results = []
            for eval_checkpoint in ckpt_state.all_model_checkpoint_paths:
                if not exists(eval_checkpoint + ".index"): continue
                global_step = int(eval_checkpoint.split("-")[-1])
                if global_step < FLAGS.start_eval_steps or global_step > FLAGS.train_steps:
                    continue
                ret = estimator.evaluate(input_fn=eval_input_fn,
                                         steps=num_eval_batch,
                                         checkpoint_path=eval_checkpoint)
                eval_results.append(ret)

            eval_results.sort(key=lambda x: x["perplexity"])

            tf.logging.info("=" * 200)
            log_str = "Best results | "
            for key, val in eval_results[0].items():
                log_str += "{} {} | ".format(key, val)
            tf.logging.info(log_str)
            tf.logging.info("=" * 200)
    else:
        if not FLAGS.do_eval:
            estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)
        else:
            for step in range(0, FLAGS.train_steps, num_train_batch):
                train_steps = min(FLAGS.train_steps - step, num_train_batch)
                estimator.train(input_fn=train_input_fn, steps=train_steps)
                estimator.evaluate(input_fn=eval_input_fn,
                                   steps=num_eval_batch)
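
The final branch interleaves training and evaluation: each pass trains at most `num_train_batch` steps (one sweep over the training data), evaluates, and repeats until `FLAGS.train_steps` is exhausted, with the last chunk clipped. The arithmetic, isolated with assumed numbers:

train_steps, num_train_batch = 10000, 3000  # assumed values
chunks = [min(train_steps - step, num_train_batch)
          for step in range(0, train_steps, num_train_batch)]
print(chunks)       # [3000, 3000, 3000, 1000]
print(sum(chunks))  # 10000 -- every training step is covered exactly once
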
Example No. 6
def main(unused_argv):
    del unused_argv  # Unused

    # Currently implemented for only one host
    assert (FLAGS.num_hosts == 1)

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    tf.compat.v1.logging.info("n_token {}".format(n_token))

    if FLAGS.do_train:
        # Get train input function
        train_data = data_utils.Dataset(
            data_dir=FLAGS.data_dir,
            record_info_dir=FLAGS.record_info_dir,
            split="train",
            per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts)

    if FLAGS.do_eval or FLAGS.do_test:
        # Get valid input function
        valid_data = data_utils.Dataset(
            data_dir=FLAGS.data_dir,
            record_info_dir=FLAGS.record_info_dir,
            split="valid",
            per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts)

        test_data = data_utils.Dataset(
            data_dir=FLAGS.data_dir,
            record_info_dir=FLAGS.record_info_dir,
            split="test",
            per_host_bsz=FLAGS.test_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.test_tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts)
    else:
        valid_data = None
        test_data = None

    if FLAGS.use_tpu:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
    else:
        strategy = tf.distribute.get_strategy()
    print("Number of accelerators: ", strategy.num_replicas_in_sync)

    # Ensure that number of replicas in sync is same as 'FLAGS.num_core_per_host'
    assert (FLAGS.num_core_per_host == strategy.num_replicas_in_sync)

    chk_name = 'texl_chk'
    if FLAGS.do_train:
        train(n_token, cutoffs, train_data, valid_data, test_data, strategy,
              chk_name)
    if FLAGS.do_test:
        evaluate(n_token, cutoffs, valid_data, test_data, strategy, chk_name)
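
When `FLAGS.use_tpu` is false, `tf.distribute.get_strategy()` returns the default strategy, which always reports a single replica, so the assertion above forces `FLAGS.num_core_per_host == 1` in non-TPU runs. A quick check:

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default strategy outside any scope
print(strategy.num_replicas_in_sync)     # 1
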
Example No. 7
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    tf.logging.info("n_token {}".format(n_token))

    tmp_Vocab = Vocab(special=["<bos>", "<eos>", "<UNK>"])
    tmp_Vocab.count_file("../data/{}/train.txt".format(FLAGS.dataset),
                         add_eos=False)
    tmp_Vocab.build_vocab()

    if FLAGS.do_sent_ppl_pred:
        encoded_txt_input = []
        txt_input = []
        input_csv = []
        with open(FLAGS.input_file_dir, "r") as read_file:
            csv_reader = csv.reader(read_file)
            for line in csv_reader:
                if line and line[0].strip():  # skip rows with an empty first field
                    input_csv.append(line)

            for i in range(1, len(input_csv)):
                txt_input.append(input_csv[i][0].strip())
                encoded_txt_input.append(list(tmp_Vocab.encode_sents(input_csv[i][0].strip(), \
                    add_eos=True, ordered=True)))

        encoded_txt_input = [
            line[:FLAGS.limit_len] if len(line) > FLAGS.limit_len else line
            for line in encoded_txt_input
        ]
        encoded_txt_input = np.array(encoded_txt_input)

        input_csv[0].append("ppl")

        pool = multiprocessing.Pool(FLAGS.multiprocess)

        parti_len = len(encoded_txt_input) // FLAGS.multiprocess
        pro_res_l = []

        for i in range(FLAGS.multiprocess):
            print("Setting process-%s" % i)
            ### TODO: add logic here to choose gpu:xx dynamically (fall back to the next GPU once gpu:1 is full)

            if i + 1 == FLAGS.multiprocess:
                end = len(encoded_txt_input)
            else:
                end = (i + 1) * parti_len
            pro_res_l.append(pool.apply_async(sent_ppl, \
                args=(encoded_txt_input[i*parti_len:end], n_token, cutoffs, "/gpu:1")))

        res_l = []

        for i in range(len(pro_res_l)):
            proc_i_res = pro_res_l[i].get()
            res_l.extend(proc_i_res)

        pool.close()
        pool.join()
        print('All subprocesses done.')

        tf.logging.info('#time: {}'.format(time.time()))

        for i in range(1, len(input_csv)):
            input_csv[i].append(res_l[i - 1])
        output_df = pd.DataFrame(input_csv[1:], columns=input_csv[0])
        output_df.to_csv(FLAGS.output_file_dir,
                         sep=",",
                         index=False,
                         encoding="utf-8-sig")

        with open("non_batch_ref_output.txt", "w") as write_res:
            for i in range(len(txt_input)):
                write_res.write(txt_input[i] + " " +
                                str(encoded_txt_input[i]) + " " +
                                str(res_l[i]) + "\n")

        # Check whether the length of result is right; Make sure multiprocess work well
        print(len(res_l))

    elif FLAGS.do_sent_gen:
        txt_gen_list = []
        with open(FLAGS.input_txt_dir, "r") as read_txt:
            for input_txt in read_txt:
                if len(input_txt.strip()) != 0:
                    txt_gen_list.append(
                        sent_gen(tmp_Vocab, input_txt.strip(), n_token,
                                 cutoffs, "/gpu:1"))

        with open("sent_generation.txt", "w") as write_res:
            for line in txt_gen_list:
                write_res.write(line + "\n")
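
The perplexity branch splits the encoded sentences into `FLAGS.multiprocess` contiguous slices, with the last worker absorbing the remainder. The same partitioning logic, isolated and runnable with assumed sizes:

data = list(range(10))  # stands in for encoded_txt_input
n_proc = 3              # stands in for FLAGS.multiprocess
parti_len = len(data) // n_proc
slices = [data[i * parti_len:(len(data) if i + 1 == n_proc
                              else (i + 1) * parti_len)]
          for i in range(n_proc)]
print(slices)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
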
Example No. 8
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    tf.logging.info("n_token {}".format(n_token))

    tmp_Vocab = Vocab(special=["<bos>", "<eos>", "<UNK>"])
    tmp_Vocab.count_file("../data/{}/train.txt".format(FLAGS.dataset),
                         add_eos=False)
    tmp_Vocab.build_vocab()

    if FLAGS.do_sent_ppl_pred:
        encoded_txt_input = []
        txt_input = []
        input_csv = []
        with open(FLAGS.input_file_dir, "r") as read_file:
            csv_reader = csv.reader(read_file)
            for line in csv_reader:
                if line and line[0].strip():  # skip rows with an empty first field
                    input_csv.append(line)

            for i in range(1, len(input_csv)):
                txt_input.append(input_csv[i][0].strip())
                encoded_txt_input.append(list(tmp_Vocab.encode_sents(input_csv[i][0].strip(), \
                    add_eos=True, ordered=True)))

        #txt_input = txt_input[:7]
        #encoded_txt_input = encoded_txt_input[:7]

        encoded_txt_input_len = [len(encoded_txt) if len(encoded_txt) <= FLAGS.limit_len else FLAGS.limit_len \
             for encoded_txt in encoded_txt_input]

        encoded_txt_input = [
            cut_pad(line, FLAGS.limit_len, 0) for line in encoded_txt_input
        ]

        if len(encoded_txt_input) % FLAGS.pred_batch_size != 0:
            pad_len = FLAGS.pred_batch_size - (len(encoded_txt_input) %
                                               FLAGS.pred_batch_size)
            encoded_txt_input = encoded_txt_input + \
                [[0]*FLAGS.limit_len]*pad_len
            encoded_txt_input_len = encoded_txt_input_len + \
                [FLAGS.limit_len]*pad_len

        encoded_txt_input = np.array(encoded_txt_input).reshape(
            FLAGS.pred_batch_size, -1, FLAGS.limit_len)
        encoded_txt_input_len = np.array(encoded_txt_input_len).reshape(
            FLAGS.pred_batch_size, -1)
        input_csv[0].append("ppl")

        if FLAGS.multiprocess == 1 or encoded_txt_input.shape[1] // FLAGS.multiprocess == 0:
            ppl_list = sent_ppl((encoded_txt_input, encoded_txt_input_len),
                                n_token, cutoffs, "/gpu:1")

            for i in range(1, len(input_csv)):
                input_csv[i].append(ppl_list[i - 1])
            output_df = pd.DataFrame(input_csv[1:], columns=input_csv[0])
            output_df.to_csv(FLAGS.output_file_dir,
                             sep=",",
                             index=False,
                             encoding="utf-8-sig")

            with open("sent_ppl_pred.txt", "w") as write_res:
                for i in range(len(txt_input)):
                    write_res.write(txt_input[i] + "\t" + str(ppl_list[i]) +
                                    "\n")

            # Check whether the length of result is right; Make sure batch-predict work well
            print(len(ppl_list))
        else:
            pool = multiprocessing.Pool(FLAGS.multiprocess)
            parti_batch_num = encoded_txt_input.shape[1] // FLAGS.multiprocess
            pro_res_l = []

            for i in range(FLAGS.multiprocess):
                print("Setting process-%s" % i)
                ### TODO: add logic here to choose gpu:xx dynamically (fall back to the next GPU once gpu:1 is full)

                if i + 1 == FLAGS.multiprocess:
                    end = encoded_txt_input.shape[1]
                else:
                    end = (i + 1) * parti_batch_num
                pro_res_l.append(pool.apply_async(sent_ppl, \
                    args=((encoded_txt_input[:,i*parti_batch_num:end,:], \
                        encoded_txt_input_len[:,i*parti_batch_num:end]), \
                        n_token, cutoffs, "/gpu:1")))

            res_l = [[] for _ in range(FLAGS.pred_batch_size)]

            for i in range(len(pro_res_l)):
                proc_i_res = pro_res_l[i].get()
                parti_len = len(proc_i_res) // FLAGS.pred_batch_size
                for j in range(FLAGS.pred_batch_size):
                    res_l[j].extend(proc_i_res[j * parti_len:(j + 1) *
                                               parti_len])

            pool.close()
            pool.join()
            print('All subprocesses done.')

            res_merge = []
            for i in range(FLAGS.pred_batch_size):
                res_merge.extend(res_l[i])
            tf.logging.info('#time: {}'.format(time.time()))

            for i in range(1, len(input_csv)):
                input_csv[i].append(res_merge[i - 1])
            output_df = pd.DataFrame(input_csv[1:], columns=input_csv[0])
            output_df.to_csv(FLAGS.output_file_dir,
                             sep=",",
                             index=False,
                             encoding="utf-8-sig")

            with open("sent_ppl_pred.txt", "w") as write_res:
                for i in range(len(txt_input)):
                    write_res.write(txt_input[i] + "\t" + str(res_merge[i]) +
                                    "\n")

            # Check whether the length of result is right; Make sure multiprocess work well
            print(len(res_merge))
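
`cut_pad` is not defined in this example. Given how it is called (sequence, target length, pad id) and the true lengths tracked separately, a plausible sketch:

def cut_pad(seq, limit_len, pad_id):
    # Hypothetical helper: truncate to limit_len, or right-pad with pad_id.
    if len(seq) > limit_len:
        return seq[:limit_len]
    return seq + [pad_id] * (limit_len - len(seq))

print(cut_pad([5, 6, 7], 5, 0))           # [5, 6, 7, 0, 0]
print(cut_pad([1, 2, 3, 4, 5, 6], 5, 0))  # [1, 2, 3, 4, 5]
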
Example No. 9
    # (tail of a bottle request handler; the function's opening lines are not in this excerpt)
    os.remove("temp_input_text.txt")

    input_text = process_text(original_text, process=True)
    output_text = process_text(output_text, process=True)
    return bottle.template(output_page,
                           input_text=original_text,
                           output_text=output_text)


def main(unused_argv):
    del unused_argv  # Unused
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    # check: http://localhost:8787/transformer
    bottle.run(host='0.0.0.0', port=8787, reloader=True)


print("Loading Spacy...")
nlp = spacy.load('en_core_web_lg')
print("Loading corpus info...")
corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
n_token = corpus_info["vocab_size"]
cutoffs = corpus_info["cutoffs"][1:-1]
test_tgt_len = FLAGS.tgt_len
vocab_path = os.path.join(FLAGS.data_dir, "cache_vocab.pkl")
print("Loading cached vocabulary...")
with open(vocab_path, "rb") as fp:
    vocab = pickle.load(fp)

tf.compat.v1.app.run()
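
This excerpt shows only the tail of a request handler plus the server startup; the route definitions live elsewhere in the module. A hypothetical sketch of the kind of bottle route the `# check: http://localhost:8787/transformer` comment implies:

import bottle

@bottle.route('/transformer')  # assumed route matching the URL in the comment
def transformer_form():
    # 'input_page' is an assumed template name; the real one is not shown.
    return bottle.template('input_page')
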
Example No. 10
def main(unused_argv):
    """ Load sentences and pre-trained model, output representations. """
    del unused_argv  # Unused
    tf.logging.set_verbosity(tf.logging.INFO)

    corpus_info = get_corpus_info('{}/corpus-info.json'.format(FLAGS.data_dir))
    n_token = corpus_info["vocab_size"]
    print(n_token)
    cutoffs = corpus_info["cutoffs"][1:-1]

    sentences = load_dataset()
    eval_dataset = eval_input_fn(sentences)
    input_feed, label_feed = eval_dataset.make_one_shot_iterator().get_next()

    # Build the computations graph.
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

        mems = [
            tf.placeholder(tf.float32, [FLAGS.mem_len, 1, FLAGS.d_model])
            for _ in range(FLAGS.n_layer)
        ]

        loss, new_mem, outputs = single_core_graph(n_token=n_token,
                                                   cutoffs=cutoffs,
                                                   is_training=False,
                                                   inp=input_feed,
                                                   tgt=label_feed,
                                                   mems=mems)

    saver = tf.train.Saver()
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, FLAGS.model_checkpoint)

        sentence_representations = []
        # iterate over sentences
        for sentence in sentences:
            char_reps_np = None
            tower_mems_np = [
                np.zeros([FLAGS.mem_len, 1, FLAGS.d_model], dtype=np.float32)
                for _ in range(FLAGS.n_layer)
            ]

            # iterate over partitions
            for _ in sentence:
                fetches = [loss, new_mem, outputs]
                feed_dict = {}
                for m_ref, m_np in zip(mems, tower_mems_np):
                    feed_dict[m_ref] = m_np

                # run the graph on our next input, store new memory and reps
                fetched = sess.run(fetches, feed_dict=feed_dict)
                _, tower_mems_np, char_rep = fetched[:3]

                # concat the partition back into the sentence
                char_rep = np.squeeze(char_rep, axis=1)
                if char_reps_np is None:
                    char_reps_np = char_rep
                else:
                    char_reps_np = np.concatenate((char_reps_np, char_rep),
                                                  axis=0)

            if FLAGS.backwards:
                char_reps_np = np.flip(char_reps_np, axis=0)

            sentence_representations.append(char_reps_np)

    tf.logging.info("Extracted features for {} sentences.".format(
        len(sentence_representations)))
    tf.logging.info("Saving the representations here: {}".format(
        FLAGS.sentence_reps_out))
    np.save(FLAGS.sentence_reps_out, sentence_representations)
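
Since each sentence yields a matrix with a different number of rows, `np.save` stores the list as an object array, and loading it back requires pickling to be enabled. A usage sketch (the file name is a placeholder for `FLAGS.sentence_reps_out`):

import numpy as np

reps = np.load("sentence_reps.npy", allow_pickle=True)  # placeholder path
print(len(reps), reps[0].shape)  # one (n_chars x d_model) matrix per sentence
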
Example No. 11
def main(unused_argv):
    del unused_argv  # Unused

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    # tf.compat.v1.logging.info("n_token {}".format(n_token))
    # tf.compat.v1.logging.info("cutoffs {}".format(cutoffs))

    if FLAGS.do_train:
        train(n_token, cutoffs, "/gpu:0")
    if FLAGS.do_eval:
        evaluate(n_token, cutoffs, "/gpu:0")
    if FLAGS.do_generate:

        def process_text(text, process=False):
            if process:
                text = text.replace(u" , ", u", ")
                text = text.replace(u" . ", u". ")
                text = text.replace(u" ; ", u"; ")
                text = text.replace(u" : ", u": ")
                text = text.replace(u" ( ", u" (")
                text = text.replace(u" ) ", u") ")
                text = text.replace(u"<eos>", u" ")
                text = text.replace(u" 's ", u"'s ")
                while text.count(u"  ") > 0:
                    text = text.replace(u"  ", u" ")
            return text

        print("Loading Spacy...")
        nlp = spacy.load('en_core_web_lg')
        test_tgt_len = FLAGS.tgt_len
        vocab_path = os.path.join(FLAGS.data_dir, "cache_vocab.pkl")
        print("Loading cached vocabulary...")
        with open(vocab_path, "rb") as fp:
            vocab = pickle.load(fp)
        while True:
            print("Enter initial text or 0 for exit:")
            text = input()
            if text == "0":
                break
            else:
                print("Enter output text length:")
                text_length = abs(int(input()))

                print("Enter randomness [1-100]:")
                word_range = abs(int(input()))
                if word_range == 0:
                    word_range = 1

                print("Processing text...")
                doc = nlp(text)
                text_out = []
                for sent in doc.sents:
                    for token in sent:
                        text_out.append(token.text)
                if len(text_out) < test_tgt_len:
                    text_out = ["<unk>"] * (test_tgt_len - len(text_out)) + text_out
                text_out = u" " + ' '.join(text_out)
                while text_out.count(u"  ") > 0:
                    text_out = text_out.replace(u"  ", u" ")

                file_o = io.open("temp_input_text.txt", "w", encoding='utf8')
                file_o.write(text_out)
                file_o.close()

                input_text, output_text = generate_text(
                    n_token,
                    cutoffs,
                    "/gpu:0",
                    vocab,
                    text_length=text_length,
                    word_range=word_range)

                print("========== Start text ==========")
                print(process_text(text, process=True))
                print("========== Generated text ==========")
                print(process_text(output_text, process=True))

                os.remove("temp_input_text.txt")

    if FLAGS.do_experiment:

        def remove_double_spaces(text):
            while text.count(u"  ") > 0:
                text = text.replace(u"  ", u" ")
            return text

        def process_text(text, process=False):
            if process:
                text = text.replace(u" , ", u", ")
                text = text.replace(u" . ", u". ")
                text = text.replace(u" ; ", u"; ")
                text = text.replace(u" : ", u": ")
                text = text.replace(u" ( ", u" (")
                text = text.replace(u" ) ", u") ")
                text = text.replace(u"<eos>", u" ")
                text = text.replace(u" 's ", u"'s ")
                text = remove_double_spaces(text)
            return text

        def process_text_2(text, process=False):
            if process:
                text = text.replace(u"<eos>", u" ")
                text = text.replace(u"<unk>", u" ")
                text = remove_double_spaces(text)
            return text

        print("Loading Spacy...")
        nlp = spacy.load('en_core_web_lg')
        test_tgt_len = FLAGS.tgt_len
        vocab_path = os.path.join(FLAGS.data_dir, "cache_vocab.pkl")
        print("Loading cached vocabulary...")
        with open(vocab_path, "rb") as fp:
            vocab = pickle.load(fp)

        output_dict = {
            "beginning": [],
            "true_end": [],
            "generated_end": [],
            "generated_end_full": []
        }

        input_file = pd.read_csv('data/output.csv')

        experiments = input_file.values
        n_exp = len(experiments)
        experiment_numbers = np.random.permutation(n_exp)

        iteration_number = 0
        for i in experiment_numbers:

            iteration_number = iteration_number + 1
            print("Experiment: {}".format(iteration_number))
            tf.compat.v1.logging.info(
                "Experiment: {}".format(iteration_number))

            text_beginning = experiments[i][0]
            text_true_end = experiments[i][1]
            text = remove_double_spaces(text_beginning)

            # output text length
            text_length = 200

            # randomness [1-100]
            word_range = 3

            doc = nlp(text)
            text_out = []
            for sent in doc.sents:
                for token in sent:
                    text_out.append(token.text)
            if len(text_out) < test_tgt_len:
                text_out = ["<unk>"] * (test_tgt_len - len(text_out)) + text_out
            text_out = u" " + ' '.join(text_out)
            text_out = remove_double_spaces(text_out)

            file_o = io.open("temp_input_text.txt", "w", encoding='utf8')
            file_o.write(text_out)
            file_o.close()

            input_text, output_text = generate_text(n_token,
                                                    cutoffs,
                                                    "/gpu:0",
                                                    vocab,
                                                    text_length=text_length,
                                                    word_range=word_range)
            output_text = process_text_2(output_text, process=True)

            output_dict["beginning"].append(text_beginning)
            output_dict["true_end"].append(text_true_end)
            output_dict["generated_end_full"].append(output_text)

            output_doc = nlp(output_text)
            output_text_trimmed = []
            NUMBER_OF_SENTENCES = 3
            sentence_num = 0
            for sent in output_doc.sents:
                for token in sent:
                    output_text_trimmed.append(token.text)
                sentence_num = sentence_num + 1
                if sentence_num >= NUMBER_OF_SENTENCES:
                    break
            output_text_trimmed = u" " + ' '.join(output_text_trimmed)
            output_text_trimmed = remove_double_spaces(output_text_trimmed)

            output_dict["generated_end"].append(output_text_trimmed)

            print("========== Start text ==========")
            print(process_text(text, process=True))
            print("========== Generated text ==========")
            print(process_text(output_text_trimmed, process=True))

            os.remove("temp_input_text.txt")

            if iteration_number % max(1, n_exp // 10) == 0:  # save partial results roughly every 10% of experiments
                output_df_full = pd.DataFrame(output_dict,
                                              columns=[
                                                  "beginning", "true_end",
                                                  "generated_end",
                                                  "generated_end_full"
                                              ])
                output_df_full.to_csv('data/output_full.csv',
                                      index=False,
                                      header=True)

                output_df_trimmed = pd.DataFrame(
                    output_dict,
                    columns=["beginning", "true_end", "generated_end"])
                output_df_trimmed.to_csv('data/output_trimmed.csv',
                                         index=False,
                                         header=True)

        output_df_full = pd.DataFrame(output_dict,
                                      columns=[
                                          "beginning", "true_end",
                                          "generated_end", "generated_end_full"
                                      ])
        output_df_full.to_csv('data/output_full.csv', index=False, header=True)

        output_df_trimmed = pd.DataFrame(
            output_dict, columns=["beginning", "true_end", "generated_end"])
        output_df_trimmed.to_csv('data/output_trimmed.csv',
                                 index=False,
                                 header=True)
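
A self-contained sanity check of the detokenization rules that `process_text` applies (the same replacement table, repackaged so it runs on its own):

def detok(text):
    for a, b in [(u" , ", u", "), (u" . ", u". "), (u" ; ", u"; "),
                 (u" : ", u": "), (u" ( ", u" ("), (u" ) ", u") "),
                 (u"<eos>", u" "), (u" 's ", u"'s ")]:
        text = text.replace(a, b)
    while u"  " in text:
        text = text.replace(u"  ", u" ")
    return text

print(detok(u"Hello , world . It <eos> is the cat 's toy ."))
# -> Hello, world. It is the cat's toy .
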