Example #1
def train():
    """Train a en->fr translation model using WMT data."""
    #with tf.device("/gpu:0"):
    # Prepare the chitchat data.
    train_path = os.path.join(FLAGS.data_dir, "chitchat.train")
    fixed_path = os.path.join(FLAGS.data_dir, "chitchat.fixed")
    weibo_path = os.path.join(FLAGS.data_dir, "chitchat.weibo")
    qa_path = os.path.join(FLAGS.data_dir, "chitchat.qa")

    voc_file_path = [
        train_path + ".answer", fixed_path + ".answer", weibo_path + ".answer",
        qa_path + ".answer", train_path + ".query", fixed_path + ".query",
        weibo_path + ".query", qa_path + ".query"
    ]

    vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.all" % FLAGS.vocab_size)

    data_utils.create_vocabulary(vocab_path, voc_file_path, FLAGS.vocab_size)

    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    print("Preparing Chitchat data in %s" % FLAGS.data_dir)
    train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data(
        FLAGS.data_dir, vocab, FLAGS.vocab_size)

    print("Preparing Fixed data in %s" % FLAGS.fixed_set_path)
    fixed_path = os.path.join(FLAGS.fixed_set_path, "chitchat.fixed")
    fixed_query, fixed_answer = data_utils.prepare_defined_data(
        fixed_path, vocab, FLAGS.vocab_size)

    print("Preparing Weibo data in %s" % FLAGS.weibo_set_path)
    weibo_path = os.path.join(FLAGS.weibo_set_path, "chitchat.weibo")
    weibo_query, weibo_answer = data_utils.prepare_defined_data(
        weibo_path, vocab, FLAGS.vocab_size)

    print("Preparing QA data in %s" % FLAGS.qa_set_path)
    qa_path = os.path.join(FLAGS.qa_set_path, "chitchat.qa")
    qa_query, qa_answer = data_utils.prepare_defined_data(
        qa_path, vocab, FLAGS.vocab_size)

    dummy_path = os.path.join(FLAGS.data_dir, "chitchat.dummy")
    dummy_set = data_utils.get_dummy_set(dummy_path, vocab, FLAGS.vocab_size)
    print("Get Dummy Set : ", dummy_set)

    with tf.Session() as sess:
        #with tf.device("/gpu:1"):
        # Create model.
        print("Creating %d layers of %d units." %
              (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, dummy_set, False)

        # Read data into buckets and compute their sizes.
        print("Reading development and training data (limit: %d)." %
              FLAGS.max_train_data_size)
        dev_set = read_data(dev_query, dev_answer)
        train_set = read_data(train_query, train_answer,
                              FLAGS.max_train_data_size)
        fixed_set = read_data(fixed_query, fixed_answer,
                              FLAGS.max_train_data_size)
        weibo_set = read_data(weibo_query, weibo_answer,
                              FLAGS.max_train_data_size)
        qa_set = read_data(qa_query, qa_answer, FLAGS.max_train_data_size)

        train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in xrange(len(train_bucket_sizes))
        ]

        # This is the training loop.
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []

        en_dict_cover = {}
        fr_dict_cover = {}
        if model.global_step.eval() > FLAGS.steps_per_checkpoint:
            try:
                with open(FLAGS.en_cover_dict_path, "rb") as ef:
                    en_dict_cover = pickle.load(ef)
                    # for line in ef.readlines():
                    #     line = line.strip()
                    #     key, value = line.strip(",")
                    #     en_dict_cover[int(key)]=int(value)
            except Exception:
                print("no find query_cover_file")
            try:
                with open(FLAGS.ff_cover_dict_path, "rb") as ff:
                    fr_dict_cover = pickle.load(ff)
                    # for line in ff.readlines():
                    #     line = line.strip()
                    #     key, value = line.strip(",")
                    #     fr_dict_cover[int(key)]=int(value)
            except Exception:
                print("no find answer_cover_file")

        step_loss_summary = tf.Summary()
        #merge = tf.merge_all_summaries()
        writer = tf.summary.FileWriter("../logs/", sess.graph)

        while True:
            # Choose a bucket according to data distribution. We pick a random number
            # in [0, 1] and use the corresponding interval in train_buckets_scale.
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in xrange(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])

            # Get a batch and make a step.
            start_time = time.time()
            encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = model.get_batch(
                train_set, bucket_id, 0, fixed_set, weibo_set, qa_set)

            if FLAGS.reinforce_learning:
                _, step_loss, _ = model.step_rl(sess, _buckets, encoder_inputs,
                                                decoder_inputs, target_weights,
                                                batch_source_encoder,
                                                batch_source_decoder,
                                                bucket_id)
            else:
                _, step_loss, _ = model.step(sess,
                                             encoder_inputs,
                                             decoder_inputs,
                                             target_weights,
                                             bucket_id,
                                             forward_only=False,
                                             force_dec_input=True)

            step_time += (time.time() -
                          start_time) / FLAGS.steps_per_checkpoint
            loss += step_loss / FLAGS.steps_per_checkpoint
            current_step += 1

            query_size, answer_size = _buckets[bucket_id]
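            # Count how often each vocabulary id shows up in this batch's queries and
            # answers; only the dict sizes are reported (and the dicts pickled) at checkpoints.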
            for batch_index in xrange(FLAGS.batch_size):
                for query_index in xrange(query_size):
                    query_word = encoder_inputs[query_index][batch_index]
                    if query_word in en_dict_cover:
                        en_dict_cover[query_word] += 1
                    else:
                        en_dict_cover[query_word] = 0

                for answer_index in xrange(answer_size):
                    answer_word = decoder_inputs[answer_index][batch_index]
                    if answer_word in fr_dict_cover:
                        fr_dict_cover[answer_word] += 1
                    else:
                        fr_dict_cover[answer_word] = 0

            # Once in a while, we save checkpoint, print statistics, and run evals.
            if current_step % FLAGS.steps_per_checkpoint == 0:

                bucket_value = step_loss_summary.value.add()
                bucket_value.tag = "loss"
                bucket_value.simple_value = float(loss)
                writer.add_summary(step_loss_summary, current_step)

                print("query_dict_cover_num: %s" %
                      (str(en_dict_cover.__len__())))
                print("answer_dict_cover_num: %s" %
                      (str(fr_dict_cover.__len__())))

                with open(FLAGS.en_cover_dict_path, "wb") as ef:
                    pickle.dump(en_dict_cover, ef)
                with open(FLAGS.ff_cover_dict_path, "wb") as ff:
                    pickle.dump(fr_dict_cover, ff)

                # Print statistics for the previous epoch.
                perplexity = math.exp(loss) if loss < 300 else float('inf')
                print(
                    "global step %d learning rate %.4f step-time %.2f perplexity "
                    "%.2f" %
                    (model.global_step.eval(), model.learning_rate.eval(),
                     step_time, perplexity))
                # Decrease learning rate if no improvement was seen over last 3 times.
                if len(previous_losses) > 2 and loss > max(
                        previous_losses[-3:]):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                # Save checkpoint and zero timer and loss.
                checkpoint_path = os.path.join(FLAGS.train_dir,
                                               "chitchat.model")
                model.saver.save(sess,
                                 checkpoint_path,
                                 global_step=model.global_step)
                step_time, loss = 0.0, 0.0
                # Run evals on development set and print their perplexity.
                # for bucket_id in xrange(len(_buckets)):
                #   encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                #       dev_set, bucket_id)
                #   _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                #                                target_weights, bucket_id, True)
                #   eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
                #   print("  eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
                sys.stdout.flush()
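
Both training examples sample a bucket in proportion to how much data it holds, using the cumulative fractions stored in train_buckets_scale. A minimal standalone sketch of that sampling step, with made-up bucket sizes:

import numpy as np

# Hypothetical bucket sizes; in the examples above these come from len(train_set[b]).
train_bucket_sizes = [4000, 3000, 2000, 1000]
train_total_size = float(sum(train_bucket_sizes))

# Cumulative fractions of the data, e.g. [0.4, 0.7, 0.9, 1.0].
train_buckets_scale = [
    sum(train_bucket_sizes[:i + 1]) / train_total_size
    for i in range(len(train_bucket_sizes))
]

# Draw a number in [0, 1) and take the first bucket whose cumulative fraction
# exceeds it, so larger buckets are picked more often.
random_number_01 = np.random.random_sample()
bucket_id = min(i for i in range(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01)
print("sampled bucket:", bucket_id)
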
Example #2
def train():
    """Train a en->fr translation model using WMT data."""
    #with tf.device("/gpu:0"):
    # Prepare the training data.
    train_path = os.path.join(FLAGS.data_dir, "weibo")
    fixed_path = os.path.join(FLAGS.data_dir, "fixed")
    weibo_path = os.path.join(FLAGS.data_dir, "wb")
    qa_path = os.path.join(FLAGS.data_dir, "qa")

    voc_file_path = [
        train_path + ".answer", fixed_path + ".answer", weibo_path + ".answer",
        qa_path + ".answer", train_path + ".query", fixed_path + ".query",
        weibo_path + ".query", qa_path + ".query"
    ]

    vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.txt" % FLAGS.vocab_size)

    data_utils.create_vocabulary(vocab_path, voc_file_path, FLAGS.vocab_size)

    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    print(len(vocab))
    print("Preparing Chitchat data in %s" % FLAGS.data_dir)
    train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data(
        FLAGS.data_dir, vocab, FLAGS.vocab_size)

    print("Preparing Fixed data in %s" % FLAGS.fixed_set_path)
    fixed_path = os.path.join(FLAGS.fixed_set_path, "wb")
    fixed_query, fixed_answer = data_utils.prepare_defined_data(
        fixed_path, vocab, FLAGS.vocab_size)

    print("Preparing Weibo data in %s" % FLAGS.weibo_set_path)
    weibo_path = os.path.join(FLAGS.weibo_set_path, "wb")
    weibo_query, weibo_answer = data_utils.prepare_defined_data(
        weibo_path, vocab, FLAGS.vocab_size)

    print("Preparing QA data in %s" % FLAGS.qa_set_path)
    qa_path = os.path.join(FLAGS.qa_set_path, "wb")
    qa_query, qa_answer = data_utils.prepare_defined_data(
        qa_path, vocab, FLAGS.vocab_size)

    dummy_path = os.path.join(FLAGS.data_dir, "dummy")
    dummy_set = data_utils.get_dummy_set(dummy_path, vocab, FLAGS.vocab_size)
    print("Get Dummy Set : ", dummy_set)
    if FLAGS.reinforce_learning and not FLAGS.dual_learning:
        import data0_utils as du
        config = {}
        config['fill_word'] = du._PAD_
        config['embedding'] = du.embedding
        config['fold'] = 1
        config['model_file'] = "model_mp"
        config['log_file'] = "dis.log"
        config['train_iters'] = 50000
        config['model_tag'] = "mxnet"
        config['batch_size'] = 64
        config['data1_maxlen'] = 46
        config['data2_maxlen'] = 74
        config['data1_psize'] = 5
        config['data2_psize'] = 5
        from importlib import import_module
        mo = import_module(config['model_file'])
        disModel = mo.Model(config)
        disSess = tf.Session()
        disModel.init_step(disSess)
        if sys.argv[1] != "no":
            disModel.saver.restore(disSess, sys.argv[1])
    outputFile = open("RL_ouput.txt", "w")
    lofFile = open("log.txt", "w")
    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        #with tf.device("/gpu:1"):
        # Create model.
        print("Creating %d layers of %d units." %
              (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, dummy_set, False, False)
        if FLAGS.dual_learning:
            du_model = create_model(sess, dummy_set, False, True)
        #sess.run(model.learning_rate_set_op)
        # Read data into buckets and compute their sizes.
        print("Reading development and training data (limit: %d)." %
              FLAGS.max_train_data_size)
        # This is the training loop.
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []

        en_dict_cover = {}
        fr_dict_cover = {}
        if model.global_step.eval() > FLAGS.steps_per_checkpoint:
            try:
                with open(FLAGS.en_cover_dict_path, "rb") as ef:
                    en_dict_cover = pickle.load(ef)
                    # for line in ef.readlines():
                    #     line = line.strip()
                    #     key, value = line.strip(",")
                    #     en_dict_cover[int(key)]=int(value)
            except Exception:
                print("no find query_cover_file")
            try:
                with open(FLAGS.ff_cover_dict_path, "rb") as ff:
                    fr_dict_cover = pickle.load(ff)
                    # for line in ff.readlines():
                    #     line = line.strip()
                    #     key, value = line.strip(",")
                    #     fr_dict_cover[int(key)]=int(value)
            except Exception:
                print("no find answer_cover_file")

        step_loss_summary = tf.Summary()
        #merge = tf.merge_all_summaries()
        writer = tf.summary.FileWriter("./logs/", sess.graph)

        while True:
            # Choose a bucket according to data distribution. We pick a random number
            # in [0, 1] and use the corresponding interval in train_buckets_scale.
            for ind in range(30):
                dev_set = read_data(dev_query, dev_answer, 0, 3000000)
                train_set = read_data(train_query, train_answer, ind * 100000,
                                      (ind + 1) * 100000)
                fixed_set = read_data(fixed_query, fixed_answer,
                                      FLAGS.max_train_data_size)
                weibo_set = read_data(weibo_query, weibo_answer,
                                      FLAGS.max_train_data_size)
                qa_set = read_data(qa_query, qa_answer,
                                   FLAGS.max_train_data_size)

                train_bucket_sizes = [
                    len(train_set[b]) for b in xrange(len(_buckets))
                ]
                train_total_size = float(sum(train_bucket_sizes))
                train_buckets_scale = [
                    sum(train_bucket_sizes[:i + 1]) / train_total_size
                    for i in xrange(len(train_bucket_sizes))
                ]
                for kk in range(500):
                    random_number_01 = np.random.random_sample()
                    bucket_id = min([
                        i for i in xrange(len(train_buckets_scale))
                        if train_buckets_scale[i] > random_number_01
                    ])

                    # Get a batch and make a step.
                    start_time = time.time()
                    encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = model.get_batch(
                        train_set, bucket_id, 0, fixed_set, weibo_set, qa_set)
                    inv_encoder_inputs, inv_decoder_inputs, inv_target_weights, inv_batch_source_encoder, inv_batch_source_decoder = model.inverse(
                        batch_source_encoder, batch_source_decoder, bucket_id)
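                    # One update step: dual learning with the inverse model, RL with the
                    # discriminator, or plain supervised training, depending on the flags.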
                    if FLAGS.reinforce_learning:
                        if FLAGS.dual_learning:
                            _, step_loss1, _ = model.step_dual(
                                sess,
                                _buckets,
                                encoder_inputs,
                                decoder_inputs,
                                target_weights,
                                batch_source_encoder,
                                batch_source_decoder,
                                bucket_id,
                                du_model,
                                rev_vocab=rev_vocab)
                            _, step_loss2, _ = du_model.step_dual(
                                sess,
                                _buckets,
                                inv_encoder_inputs,
                                inv_decoder_inputs,
                                inv_target_weights,
                                inv_batch_source_encoder,
                                inv_batch_source_decoder,
                                bucket_id,
                                model,
                                rev_vocab=rev_vocab)
                            step_loss = []
                            for ii in range(len(step_loss1)):
                                step_loss.append(step_loss1[ii] +
                                                 step_loss2[ii])
                        else:
                            _, step_loss, _ = model.step_rl(
                                sess,
                                _buckets,
                                encoder_inputs,
                                decoder_inputs,
                                target_weights,
                                batch_source_encoder,
                                batch_source_decoder,
                                bucket_id,
                                rev_vocab=rev_vocab,
                                disSession=disSess,
                                disModel=disModel)
                    else:
                        _, step_loss, _ = model.step(sess,
                                                     encoder_inputs,
                                                     decoder_inputs,
                                                     target_weights,
                                                     bucket_id,
                                                     forward_only=False,
                                                     force_dec_input=True)

                    lossmean = float(np.mean(step_loss))
                    loss += lossmean / FLAGS.steps_per_checkpoint
                    step_time += (time.time() -
                                  start_time) / FLAGS.steps_per_checkpoint
                    current_step += 1

                    query_size, answer_size = _buckets[bucket_id]
                    for batch_index in xrange(FLAGS.batch_size):
                        for query_index in xrange(query_size):
                            query_word = encoder_inputs[query_index][
                                batch_index]
                            if query_word in en_dict_cover:
                                en_dict_cover[query_word] += 1
                            else:
                                en_dict_cover[query_word] = 0

                        for answer_index in xrange(answer_size):
                            answer_word = decoder_inputs[answer_index][
                                batch_index]
                            if answer_word in fr_dict_cover:
                                fr_dict_cover[answer_word] += 1
                            else:
                                fr_dict_cover[answer_word] = 0

                    # Once in a while, we save checkpoint, print statistics, and run evals.
                    if current_step % FLAGS.steps_per_checkpoint == 0:
                        outputFile = open(
                            "OpenSubData/RL_" + str(model.global_step.eval()) +
                            ".txt", "w")
                        bucket_value = step_loss_summary.value.add()
                        bucket_value.tag = "loss"
                        bucket_value.simple_value = float(loss)
                        writer.add_summary(step_loss_summary, current_step)

                        print("query_dict_cover_num: %s" %
                              (str(en_dict_cover.__len__())))
                        print("answer_dict_cover_num: %s" %
                              (str(fr_dict_cover.__len__())))

                        with open(FLAGS.en_cover_dict_path, "wb") as ef:
                            pickle.dump(en_dict_cover, ef)
                        with open(FLAGS.ff_cover_dict_path, "wb") as ff:
                            pickle.dump(fr_dict_cover, ff)
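                        # Dev-set evaluation: accumulate the eval loss, write sample
                        # query / reference / generated-reply triples to outputFile, and
                        # collect unigram / bigram counts (dictt, dictt_b) for distinct-1/2.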
                        num = 0
                        pick = 0.
                        mmm = 1
                        eval_loss = 0
                        dictt = {}
                        dictt_b = {}
                        for idd in range(2):
                            bucket_id = idd + 2
                            batch_num = 1 + int(
                                len(dev_set[bucket_id]) / FLAGS.batch_size)
                            for mm in range(batch_num):
                                encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = model.get_batch_dev(
                                    dev_set, bucket_id, mm * FLAGS.batch_size,
                                    fixed_set, weibo_set, qa_set)
                                _, eval_loss_per, output_logits = model.step(
                                    sess,
                                    encoder_inputs,
                                    decoder_inputs,
                                    target_weights,
                                    bucket_id,
                                    forward_only=True,
                                    force_dec_input=False)
                                #_, eval_loss_per, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only=False, force_dec_input=True)
                                eval_loss += np.mean(eval_loss_per)
                                resp_tokens = model.remove_type(
                                    output_logits,
                                    model.buckets[bucket_id],
                                    type=1)
                                #prob = model.calprob(sess,_buckets, encoder_inputs, decoder_inputs, target_weights,batch_source_encoder, batch_source_decoder, bucket_id,rev_vocab=rev_vocab)
                                resp_c = model.ids2tokens(
                                    resp_tokens, rev_vocab)
                                resp_b = model.ids2tokens(
                                    batch_source_decoder, rev_vocab)
                                resp_a = model.ids2tokens(
                                    batch_source_encoder, rev_vocab)
                                for ii in range(len(resp_a)):
                                    aa = ""
                                    for ww in resp_a[ii]:
                                        aa = aa + " " + ww
                                    bb = ""
                                    for ww in resp_b[ii]:
                                        bb = bb + " " + ww
                                    cc = ""
                                    pre = ""
                                    for ww in resp_c[ii]:
                                        cc = cc + " " + ww
                                        if ww not in dictt:
                                            dictt[ww] = 0
                                        if pre + ww not in dictt_b:
                                            dictt_b[pre + ww] = 0
                                        dictt[ww] += 1
                                        dictt_b[pre + ww] += 1
                                        pre = ww
                                    #print("Q:",aa)
                                    #print("A1:",bb)
                                    #print("A2:",cc)
                                    #print("\n")
                                    outputFile.write("%s\n%s\n%s \n\n" %
                                                     (aa, bb, cc))
                                    outputFile.flush()
                                    BLEUscore = nltk.translate.bleu_score.sentence_bleu(
                                        [resp_c[ii]], resp_b[ii])
                                    print(BLEUscore)
                                    #eval_loss += BLEUscore
                                mmm += 1
                                #dummy = model.caldummy(sess,_buckets, encoder_inputs, decoder_inputs, target_weights,batch_source_encoder, batch_source_decoder, bucket_id,rev_vocab=rev_vocab)
                                #print(dummy)
                                #eval_loss +=dummy
                        eval_loss = eval_loss / mmm

                        # Print statistics for the previous epoch.
                        perplexity = math.exp(loss) if loss < 300 else float(
                            'inf')
                        print(
                            "global step %d learning rate %.4f step-time %.2f loss "
                            "%.2f" %
                            (model.global_step.eval(),
                             model.learning_rate.eval(), step_time, loss))
                        # Decrease learning rate if no improvement was seen over last 3 times.
                        if len(previous_losses) > 2 and loss > max(
                                previous_losses[-3:]):
                            sess.run(model.learning_rate_decay_op)
                            if FLAGS.dual_learning:
                                sess.run(du_model.learning_rate_decay_op)
                        previous_losses.append(loss)
                        # Save checkpoint and zero timer and loss.
                        checkpoint_path = os.path.join(FLAGS.train_dir,
                                                       "weibo.model")
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=model.global_step)
                        if FLAGS.dual_learning:
                            checkpoint_path2 = os.path.join(
                                FLAGS.train_dir2, "weibo.du_model")
                            du_model.saver.save(sess,
                                                checkpoint_path2,
                                                global_step=model.global_step)

                        eval_ppx = math.exp(
                            eval_loss) if eval_loss < 300 else float('inf')
                        summ = 1.0 * sum(dictt.values())
                        print(
                            "  eval: %.5f  bucket %d distinct-1 %.5f  distinct-2  %.5f "
                            % (eval_loss, bucket_id, len(dictt) / summ,
                               len(dictt_b) / summ))
                        lofFile.write("%.2f   %.2f\n" % (loss, eval_loss))
                        lofFile.flush()
                        step_time, loss = 0.0, 0.0
                        # Run evals on development set and print their perplexity.
                        # for bucket_id in xrange(len(_buckets)):
                        #   encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                        #       dev_set, bucket_id)
                        #   _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                        #                                target_weights, bucket_id, True)
                        #   eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
                        #   print("  eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
                        sys.stdout.flush()
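
Example #2 reports distinct-1 and distinct-2 by counting the unigrams and bigrams (dictt, dictt_b) of the generated replies. A self-contained sketch of the same metric, assuming replies arrive as lists of tokens:

def distinct_1_2(replies):
    """Ratio of unique unigrams and bigrams to the total number of generated tokens."""
    unigrams, bigrams = {}, {}
    total = 0
    for tokens in replies:
        prev = ""
        for w in tokens:
            unigrams[w] = unigrams.get(w, 0) + 1
            bigrams[prev + w] = bigrams.get(prev + w, 0) + 1
            prev = w
            total += 1
    total = float(max(total, 1))
    return len(unigrams) / total, len(bigrams) / total

print(distinct_1_2([["i", "am", "fine"], ["i", "am", "ok"]]))
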
Example #3
def test_decoder(config):
    train_path = os.path.join(config.train_dir, "movie_subtitle.train")
    data_path_list = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(config.train_dir,
                              "vocab%d.all" % config.vocab_size)
    #    data_utils.create_vocabulary(vocab_path, data_path_list, config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    dummy_set = data_utils.get_dummy_set("grl_data/dummy_sentence", vocab,
                                         25000)
    with tf.Session() as sess:
        if config.name_model in [
                gst_config.name_model, gcc_config.name_model,
                gbk_config.name_model
        ]:
            model = create_st_model(sess,
                                    config,
                                    forward_only=True,
                                    name_scope=config.name_model)

        elif config.name_model in [
                grl_config.name_model, pre_grl_config.name_model
        ]:
            model = create_rl_model(sess,
                                    config,
                                    forward_only=True,
                                    name_scope=config.name_model,
                                    dummy_set=dummy_set)

        model.batch_size = 1

        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(sentence), vocab)
            print("token_id: ", token_ids)
            bucket_id = len(config.buckets) - 1
            for i, bucket in enumerate(config.buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                print("Sentence truncated: %s", sentence)

            encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                {bucket_id: [(token_ids, [1])]}, bucket_id)
            # st_model step
            if config.name_model in [
                    gst_config.name_model, gcc_config.name_model,
                    gbk_config.name_model
            ]:
                output_logits, _ = model.step(sess, encoder_inputs,
                                              decoder_inputs, target_weights,
                                              bucket_id, True)
                outputs = [
                    int(np.argmax(logit, axis=1)) for logit in output_logits
                ]
                if data_utils.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                print(" ".join([str(rev_vocab[output]) for output in outputs]))

            # beam_search step
            elif config.name_model in [
                    grl_config.name_model, pre_grl_config.name_model
            ]:
                _, _, output_logits = model.step(sess,
                                                 encoder_inputs,
                                                 decoder_inputs,
                                                 target_weights,
                                                 reward=1,
                                                 bucket_id=bucket_id,
                                                 forward_only=True)
                #output_logits = np.reshape(output_logits,[1,-1,25000])
                output_logits = np.squeeze(output_logits)
                outputs = np.argmax(output_logits, axis=1)
                outputs = list(outputs)
                # for i, output in enumerate(output_logits):
                #     print("index: %d, answer tokens: %s" %(i, str(output)))
                #     if data_utils.EOS_ID in output:
                #         output = output[:output.index(data_utils.EOS_ID)]
                if data_utils.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                print(outputs)
                while data_utils.UNK_ID in outputs:
                    sub_max = np.argmax(output_logits[outputs.index(
                        data_utils.UNK_ID)][4:]) + 4
                    outputs[outputs.index(data_utils.UNK_ID)] = sub_max
                print(" ".join([str(rev_vocab[out]) for out in outputs]))

            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
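
Examples #3 and #4 post-process the decoder output the same way: greedy argmax per step, truncation at EOS, and replacement of any UNK by the best regular token at that step. A standalone sketch on a dummy logits array; the special-token ids are assumptions in line with the usual data_utils layout (PAD=0, GO=1, EOS=2, UNK=3):

import numpy as np

EOS_ID, UNK_ID = 2, 3  # assumed special-token ids

def greedy_postprocess(output_logits):
    """Greedy decode, cut at EOS, and swap any UNK for the next-best regular token."""
    output_logits = np.squeeze(output_logits)        # -> (steps, vocab_size)
    outputs = list(np.argmax(output_logits, axis=1))
    if EOS_ID in outputs:
        outputs = outputs[:outputs.index(EOS_ID)]
    while UNK_ID in outputs:
        step = outputs.index(UNK_ID)
        # best token with id >= 4, i.e. skipping the special symbols
        outputs[step] = np.argmax(output_logits[step][4:]) + 4
    return outputs

print(greedy_postprocess(np.random.randn(1, 6, 20)))
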
Example #4
def read_file_test(config, test_model_name, input_path, output_path):
    vocab_path = os.path.join(config.train_dir,
                              "vocab%d.all" % config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    dummy_set = data_utils.get_dummy_set("grl_data/dummy_sentence", vocab,
                                         25000)
    forward_only = True
    with tf.Session() as sess:
        with tf.variable_scope(name_or_scope=config.name_model):
            model = grl_rnn_model.grl_model(grl_config=config,
                                            name_scope=config.name_model,
                                            forward=forward_only,
                                            dummy_set=dummy_set)
            #ckpt = tf.train.get_checkpoint_state(os.path.join(rl_config.train_dir, "checkpoints"))
            #print (ckpt.model_checkpoint_path)
            model.batch_size = 1
            if test_model_name == 'S2S':
                model.saver.restore(sess,
                                    "grl_data/movie_subtitle.model-118000")
            elif test_model_name == 'RL':
                model.saver.restore(sess,
                                    "grl_data/movie_subtitle.model-127200")
            else:
                model.saver.restore(sess,
                                    "grl_data/movie_subtitle.model-127200")
            with open(input_path) as f:
                sentences = f.readlines()
            output_file = []
            for sentence in sentences:
                token_ids = data_utils.sentence_to_token_ids(
                    tf.compat.as_bytes(sentence), vocab)
                print("token_id: ", token_ids)
                bucket_id = len(config.buckets) - 1
                for i, bucket in enumerate(config.buckets):
                    if bucket[0] >= len(token_ids):
                        bucket_id = i
                        break
                else:
                    print("Sentence truncated: %s", sentence)

                encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                    {bucket_id: [(token_ids, [1])]}, bucket_id)
                _, _, output_logits = model.step(sess,
                                                 encoder_inputs,
                                                 decoder_inputs,
                                                 target_weights,
                                                 reward=1,
                                                 bucket_id=bucket_id,
                                                 forward_only=True)
                #output_logits = np.reshape(output_logits,[1,-1,25000])
                output_logits = np.squeeze(output_logits)
                outputs = np.argmax(output_logits, axis=1)
                outputs = list(outputs)

                # for i, output in enumerate(output_logits):
                #     print("index: %d, answer tokens: %s" %(i, str(output)))
                #     if data_utils.EOS_ID in output:
                #         output = output[:output.index(data_utils.EOS_ID)]
                if data_utils.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                print(outputs)
                while data_utils.UNK_ID in outputs:
                    sub_max = np.argmax(output_logits[outputs.index(
                        data_utils.UNK_ID)][4:]) + 4
                    outputs[outputs.index(data_utils.UNK_ID)] = sub_max
                while 30 in outputs:
                    outputs.remove(30)

                output_sentence = " ".join(
                    [str(rev_vocab[out]) for out in outputs]) + '\n'
                #while '$' in output_sentence:
                output_sentence = output_sentence.replace("$", "")
                print(output_sentence)
                output_file.append(output_sentence)
            with open(output_path, 'w') as f:
                f.writelines(output_file)
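
Both decoder examples pick the smallest bucket that fits the tokenized input with Python's for/else; the else branch only runs when the loop never breaks, i.e. when no bucket is long enough. A minimal sketch with illustrative bucket sizes:

buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]  # illustrative (query_len, answer_len) buckets
token_ids = list(range(12))                        # pretend the sentence tokenized to 12 ids

bucket_id = len(buckets) - 1
for i, bucket in enumerate(buckets):
    if bucket[0] >= len(token_ids):
        bucket_id = i
        break
else:
    # Only reached when the sentence is longer than every bucket;
    # it will be truncated to the largest one.
    print("Sentence truncated.")

print("bucket_id:", bucket_id)  # -> 2 for the values above
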
Example #5
def train():
    vocab, rev_vocab, train_set = prepare_data(grl_config)
    for b_set in train_set:
        print("b_set length: ", len(b_set))
    dummy_set = data_utils.get_dummy_set("grl_data/dummy_sentence", vocab,
                                         25000)
    with tf.Session() as sess:
        rl_model = create_rl_model(sess, grl_config, False,
                                   grl_config.name_model, dummy_set)
        st_model = create_st_model(sess, gst_config, True,
                                   gst_config.name_model)
        #bk_model = create_st_model(sess, gbk_config, True, gbk_config.name_model)
        #cc_model = create_st_model(sess, gcc_config, True, gcc_config.name_model)

        train_bucket_sizes = [
            len(train_set[b]) for b in range(len(grl_config.buckets))
        ]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in range(len(train_bucket_sizes))
        ]

        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        step_loss_summary = tf.Summary()
        # merge = tf.merge_all_summaries()
        rl_writer = tf.summary.FileWriter(grl_config.tensorboard_dir,
                                          sess.graph)
        while True:
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in range(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])

            # Get a batch and make a step.
            start_time = time.time()
            encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, _ = \
                rl_model.get_batch(train_set,bucket_id)

            update, norm, step_loss = rl_model.step_rl(
                sess,
                st_model=st_model,
                bk_model=st_model,
                cc_model=st_model,
                encoder_inputs=encoder_inputs,
                decoder_inputs=decoder_inputs,
                target_weights=target_weights,
                batch_source_encoder=batch_source_encoder,
                bucket_id=bucket_id)

            step_time += (time.time() -
                          start_time) / grl_config.steps_per_checkpoint
            loss += step_loss / grl_config.steps_per_checkpoint
            current_step += 1

            # Once in a while, we save checkpoint, print statistics, and run evals.
            if current_step % grl_config.steps_per_checkpoint == 0:

                bucket_value = step_loss_summary.value.add()
                bucket_value.tag = grl_config.name_loss
                bucket_value.simple_value = float(loss)
                rl_writer.add_summary(step_loss_summary,
                                      int(sess.run(rl_model.global_step)))

                # Print statistics for the previous epoch.
                perplexity = math.exp(loss) if loss < 300 else float('inf')
                print(
                    "global step %d learning rate %.4f step-time %.2f perplexity "
                    "%.2f" %
                    (rl_model.global_step.eval(),
                     rl_model.learning_rate.eval(), step_time, perplexity))
                # Decrease learning rate if no improvement was seen over last 3 times.
                if len(previous_losses) > 2 and loss > max(
                        previous_losses[-3:]):
                    sess.run(rl_model.learning_rate_decay_op)
                previous_losses.append(loss)
                # Save checkpoint and zero timer and loss.
                gen_ckpt_dir = os.path.abspath(
                    os.path.join(grl_config.train_dir, "checkpoints"))
                if not os.path.exists(gen_ckpt_dir):
                    os.makedirs(gen_ckpt_dir)
                checkpoint_path = os.path.join(gen_ckpt_dir,
                                               "movie_subtitle.model")
                rl_model.saver.save(sess,
                                    checkpoint_path,
                                    global_step=rl_model.global_step)
                step_time, loss = 0.0, 0.0
                # Run evals on development set and print their perplexity.
                # for bucket_id in xrange(len(gen_config.buckets)):
                #   encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                #       dev_set, bucket_id)
                #   _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                #                                target_weights, bucket_id, True)
                #   eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
                #   print("  eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
                sys.stdout.flush()
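
All of the training loops above decay the learning rate when the checkpoint loss has not improved over the last three checkpoints. The decision itself is plain Python; here is an isolated sketch in which a bare float stands in for sess.run(model.learning_rate_decay_op):

learning_rate = 0.5
decay_factor = 0.99   # illustrative; the real factor lives inside learning_rate_decay_op
previous_losses = []

for loss in [4.0, 3.5, 3.4, 3.6, 3.7, 3.8]:   # pretend per-checkpoint losses
    if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
        learning_rate *= decay_factor          # stands in for the decay op
        print("no improvement for 3 checkpoints, learning rate ->", learning_rate)
    previous_losses.append(loss)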