Example 1
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(
        tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception(
                "Logdir %s doesn't exist. Run in train mode to create it." %
                (FLAGS.log_root))

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'vocab_size', 'dataset', 'mode', 'lr', 'adagrad_init_acc',
        'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm',
        'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num',
        'max_enc_num', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
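    # note: depending on the TensorFlow/absl version, FLAGS.__flags may map names to
    # Flag objects rather than raw values, in which case val.value (or
    # FLAGS.flag_values_dict()) is needed instead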
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
        'srl_max_dec_seq_len', 'srl_max_dec_sen_num', 'srl_max_enc_seq_len',
        'srl_max_enc_sen_num'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_srl_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
        'sc_max_dec_seq_len', 'sc_max_enc_seq_len'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_sc_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # Create a batcher object that will create minibatches of data

    sc_batcher = Sc_GenBatcher(vocab, hps_sc_generator)

    tf.set_random_seed(111)  # a seed value for randomness

    if hps_generator.mode == 'train':

        print("Start pre-training......")
        sc_model = Sc_Generator(hps_sc_generator, vocab)

        sess_sc, saver_sc, train_dir_sc = setup_training_sc_generator(sc_model)
        sc_generated = Generated_sc_sample(sc_model, vocab, sess_sc)
        print("Start pre-training generator......")
        run_pre_train_sc_generator(sc_model, sc_batcher, 40, sess_sc, saver_sc,
                                   train_dir_sc, sc_generated)

        if not os.path.exists("data/" + str(0) + "/"):
            os.mkdir("data/" + str(0) + "/")
        sc_generated.generator_max_example_test(
            sc_batcher.get_batches("pre-train"),
            "data/" + str(0) + "/train_skeleton.txt")

        sc_generated.generator_max_example_test(
            sc_batcher.get_batches("pre-valid"),
            "data/" + str(0) + "/valid_skeleton.txt")

        sc_generated.generator_max_example_test(
            sc_batcher.get_batches("pre-test"),
            "data/" + str(0) + "/test_skeleton.txt")

        merge("data/story/train_process.txt", "data/0/train_skeleton.txt",
              "data/0/train.txt")
        merge("data/story/validation_process.txt", "data/0/valid_skeleton.txt",
              "data/0/valid.txt")
        merge("data/story/test_process.txt", "data/0/test_skeleton.txt",
              "data/0/test.txt")

        #################################################################################################
        batcher = GenBatcher(vocab, hps_generator)
        srl_batcher = Srl_GenBatcher(vocab, hps_srl_generator)
        print("Start pre-training......")
        model = Generator(hps_generator, vocab)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
        generated = Generated_sample(model, vocab, sess_ge)
        print("Start pre-training generator......")
        run_pre_train_generator(model, batcher, 30, sess_ge, saver_ge,
                                train_dir_ge, generated)
        ##################################################################################################
        srl_generator_model = Srl_Generator(hps_srl_generator, vocab)

        sess_srl_ge, saver_srl_ge, train_dir_srl_ge = setup_training_srl_generator(
            srl_generator_model)
        util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        util.load_ckpt(saver_sc, sess_sc, ckpt_dir="train-sc-generator")
        srl_generated = Generated_srl_sample(srl_generator_model, vocab,
                                             sess_srl_ge)
        whole_generated = Generated_whole_sample(model, srl_generator_model,
                                                 vocab, sess_ge, sess_srl_ge,
                                                 batcher, srl_batcher)
        print("Start pre-training srl_generator......")
        run_pre_train_srl_generator(srl_generator_model, batcher, srl_batcher,
                                    20, sess_srl_ge, saver_srl_ge,
                                    train_dir_srl_ge, srl_generated,
                                    whole_generated)

        loss_window = 0
        t0 = time.time()
        print("begin reinforcement learning:")
        for epoch in range(10):

            loss_window = 0.0

            batcher = GenBatcher(vocab, hps_generator)
            srl_batcher = Srl_GenBatcher(vocab, hps_srl_generator)

            batches = batcher.get_batches(mode='train')
            srl_batches = srl_batcher.get_batches(mode='train')
            sc_batches = sc_batcher.get_batches(mode='train')
            len_sc = len(sc_batches)

            for i in range(min(len(batches), len(srl_batches))):
                current_batch = batches[i]
                current_srl_batch = srl_batches[i]
                current_sc_batch = sc_batches[i % (len_sc - 1)]
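                # skeleton batches are reused cyclically, since there may be fewer of them
                # than generator batches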

                results = model.run_pre_train_step(sess_ge, current_batch)
                loss_list = results['without_average_loss']

                example_skeleton_list = current_batch.original_review_outputs
                example_text_list = current_batch.original_target_sentences
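                # build a skeleton-extractor batch weighted by the generator's per-example
                # loss, which is presumably used as the RL reward signal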

                new_batch = sc_batcher.get_text_queue(example_skeleton_list,
                                                      example_text_list,
                                                      loss_list)
                results_sc = sc_model.run_rl_train_step(sess_sc, new_batch)
                loss = results_sc['loss']
                loss_window += loss

                results_srl = srl_generator_model.run_pre_train_step(
                    sess_srl_ge, current_srl_batch)
                loss_list_srl = results_srl['without_average_loss']

                example_srl_text_list = current_srl_batch.orig_outputs
                example_skeleton_srl_list = current_srl_batch.orig_inputs

                new_batch = sc_batcher.get_text_queue(
                    example_skeleton_srl_list, example_srl_text_list,
                    loss_list_srl)
                results_sc = sc_model.run_rl_train_step(sess_sc, new_batch)
                loss = results_sc['loss']
                loss_window += loss

                results_sc = sc_model.run_rl_train_step(
                    sess_sc, current_sc_batch)
                loss = results_sc['loss']
                loss_window += loss

                train_step = results['global_step']

                if train_step % 100 == 0:
                    t1 = time.time()
                    tf.logging.info(
                        'seconds for %d training generator step: %.3f ',
                        train_step, (t1 - t0) / 300)
                    t0 = time.time()
                    tf.logging.info('loss: %f', loss_window /
                                    100)  # print the loss to screen
                    loss_window = 0.0

                train_srl_step = results_srl['global_step']

                if train_srl_step % 10000 == 0:
                    saver_sc.save(sess_sc,
                                  train_dir_sc + "/model",
                                  global_step=results_sc['global_step'])
                    saver_ge.save(sess_ge,
                                  train_dir_ge + "/model",
                                  global_step=train_step)
                    saver_srl_ge.save(sess_srl_ge,
                                      train_dir_srl_ge + "/model",
                                      global_step=train_srl_step)

                    srl_generated.generator_max_example(
                        srl_batcher.get_batches("validation"),
                        "to_seq_max_generated/valid/" +
                        str(int(train_srl_step / 30000)) + "_positive",
                        "to_seq_max_generated/valid/" +
                        str(int(train_srl_step / 30000)) + "_negative")
                    srl_generated.generator_max_example(
                        srl_batcher.get_batches("test"),
                        "to_seq_max_generated/test/" +
                        str(int(train_srl_step / 30000)) + "_positive",
                        "to_seq_max_generated/test/" +
                        str(int(train_srl_step / 30000)) + "_negative")

                    whole_generated.generator_max_example(
                        batcher.get_batches("test-validation"),
                        "max_generated_final/valid/" +
                        str(int(train_srl_step / 30000)) + "_positive",
                        "max_generated_final/valid/" +
                        str(int(train_srl_step / 30000)) + "_negative")
                    whole_generated.generator_max_example(
                        batcher.get_batches("test-test"),
                        "max_generated_final/test/" +
                        str(int(train_srl_step / 30000)) + "_positive",
                        "max_generated_final/test/" +
                        str(int(train_srl_step / 30000)) + "_negative")

            sc_generated.generator_max_example_test(
                sc_batcher.get_batches("pre-train"),
                "data/" + str(0) + "/train_skeleton.txt")

            sc_generated.generator_max_example_test(
                sc_batcher.get_batches("pre-valid"),
                "data/" + str(0) + "/valid_skeleton.txt")

            sc_generated.generator_max_example_test(
                sc_batcher.get_batches("pre-test"),
                "data/" + str(0) + "/test_skeleton.txt")

            merge("data/story/train_process.txt", "data/0/train_skeleton.txt",
                  "data/0/train.txt")
            merge("data/story/validation_process.txt",
                  "data/0/valid_skeleton.txt", "data/0/valid.txt")
            merge("data/story/test_process.txt", "data/0/test_skeleton.txt",
                  "data/0/test.txt")

    else:
        raise ValueError("The 'mode' flag must be one of train/eval/decode")
Example 2
def main(unused_argv):
    # %%
    # choose what level of logging you want
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info('Starting running in %s mode...', (FLAGS.mode))
    # create the vocabulary
    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)

    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_sen_num', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():
        if key in hparam_list:
            hps_dict[key] = val.value  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
        'max_enc_sen_num', 'max_enc_seq_len'
    ]
    hps_dict = {}

    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:
            hps_dict[key] = val.value  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # take out the data for the minimum batch size
    batcher = GenBatcher(vocab, hps_generator)
    # print(batcher.train_batch[0].original_review_inputs)
    # print(len(batcher.train_batch[0].original_review_inputs))
    tf.set_random_seed(123)
    # %%
    if FLAGS.mode == 'train_generator':

        # print("Start pre-training ......")
        ge_model = Generator(hps_generator, vocab)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(ge_model)

        generated = Generated_sample(ge_model, vocab, batcher, sess_ge)
        print("Start pre-training generator......")
        run_pre_train_generator(ge_model, batcher, 300, sess_ge, saver_ge,
                                train_dir_ge)
        # util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        print("finish load train-generator")

        print("Generating negative examples......")
        generator_graph = tf.Graph()
        with generator_graph.as_default():
            util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
            print("finish load train-generator")

        generated.generator_train_negative_example()
        generated.generator_test_negative_example()

        print("finish write")
    elif FLAGS.mode == 'train_discriminator':
        # print("Start pre-training ......")
        model_dis = Discriminator(hps_discriminator, vocab)
        dis_batcher = DisBatcher(hps_discriminator, vocab,
                                 "discriminator_train/positive/*",
                                 "discriminator_train/negative/*",
                                 "discriminator_test/positive/*",
                                 "discriminator_test/negative/*")
        sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
            model_dis)

        print("Start pre-training discriminator......")
        if not os.path.exists("discriminator_result"):
            os.mkdir("discriminator_result")
        run_pre_train_discriminator(model_dis, dis_batcher, 1000, sess_dis,
                                    saver_dis, train_dir_dis)

    elif FLAGS.mode == "adversarial_train":

        generator_graph = tf.Graph()
        discriminator_graph = tf.Graph()

        print("Start adversarial-training......")
        # tf.reset_default_graph()

        with generator_graph.as_default():
            model = Generator(hps_generator, vocab)
            sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
            generated = Generated_sample(model, vocab, batcher, sess_ge)

            util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
            print("finish load train-generator")
        with discriminator_graph.as_default():
            model_dis = Discriminator(hps_discriminator, vocab)
            dis_batcher = DisBatcher(hps_discriminator, vocab,
                                     "discriminator_train/positive/*",
                                     "discriminator_train/negative/*",
                                     "discriminator_test/positive/*",
                                     "discriminator_test/negative/*")
            sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
                model_dis)

            util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator")
            print("finish load train-discriminator")

        print("Start adversarial  training......")
        if not os.path.exists("train_sample_generated"):
            os.mkdir("train_sample_generated")
        if not os.path.exists("test_max_generated"):
            os.mkdir("test_max_generated")
        if not os.path.exists("test_sample_generated"):
            os.mkdir("test_sample_generated")

        whole_decay = False

        for epoch in range(100):
            print('Start training')
            batches = batcher.get_batches(mode='train')
            for step in range(int(len(batches) / 14)):

                run_train_generator(model, model_dis, sess_dis, batcher,
                                    dis_batcher,
                                    batches[step * 14:(step + 1) * 14],
                                    sess_ge, saver_ge, train_dir_ge)
                generated.generator_sample_example(
                    "train_sample_generated/" + str(epoch) + "epoch_step" +
                    str(step) + "_temp_positive", "train_sample_generated/" +
                    str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                    14)

                tf.logging.info("test performance: ")
                tf.logging.info("epoch: " + str(epoch) + " step: " + str(step))

                #                print("evaluate the diversity of DP-GAN (decode based on  max probability)")
                #                generated.generator_test_sample_example(
                #                    "test_sample_generated/" +
                #                    str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                #                    "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 14)
                #
                #                print("evaluate the diversity of DP-GAN (decode based on sampling)")
                #                generated.generator_test_max_example(
                #                    "test_max_generated/" +
                #                    str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                #                    "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 14)

                dis_batcher.train_queue = []
                for i in range(epoch + 1):
                    for j in range(step + 1):
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_positive/*")
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_negative/*")
                dis_batcher.train_batch = dis_batcher.create_batches(
                    mode="train", shuffleis=True)
                whole_decay = run_train_discriminator(
                    model_dis, 5, dis_batcher,
                    dis_batcher.get_batches(mode="train"), sess_dis, saver_dis,
                    train_dir_dis, whole_decay)
    elif FLAGS.mode == "test_language_model":
        ge_model = Generator(hps_generator, vocab)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(ge_model)
        util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        print("finish load train-generator")

        #        generator_graph = tf.Graph()
        #        with generator_graph.as_default():
        #            util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        #            print("finish load train-generator")

        # jieba.load_userdict('dir.txt')
        inputs = ''
        while inputs != "close":

            inputs = input("Enter your ask: ")
            sentence = segmentor.segment(t2s.convert(inputs))
            #            sentence = jieba.cut(inputs)
            sentence = (" ".join(sentence))
            sentence = s2t.convert(sentence)
            print(sentence)
            sentence = sentence.split()

            enc_input = [vocab.word2id(w) for w in sentence]
            enc_lens = np.array([len(enc_input)])
            enc_input = np.array([enc_input])
            out_sentence = ('[START]').split()
            dec_batch = [vocab.word2id(w) for w in out_sentence]
            #dec_batch = [2] + dec_batch
            # dec_batch.append(3)
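            # pad the decoder input to a fixed length of 40 (id 1 is assumed to be the
            # padding token here)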
            while len(dec_batch) < 40:
                dec_batch.append(1)

            dec_batch = np.array([dec_batch])
            dec_batch = np.resize(dec_batch, (1, 1, 40))
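            # reshape to what is presumably (batch, num_sentences, max_dec_steps) = (1, 1, 40)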
            dec_lens = np.array([len(dec_batch)])
            if (FLAGS.beamsearch == 'beamsearch_train'):
                result = ge_model.run_test_language_model(
                    sess_ge, enc_input, enc_lens, dec_batch, dec_lens)
                #                print(result['generated'])
                #                print(result['generated'].shape)
                output_ids = result['generated'][0]
                decoded_words = data.outputids2words(output_ids, vocab, None)
                print("decoded_words :", decoded_words)
            else:
                results = ge_model.run_test_beamsearch_example(
                    sess_ge, enc_input, enc_lens, dec_batch, dec_lens)
                beamsearch_outputs = results['beamsearch_outputs']
                for i in range(5):
                    predict_list = np.ndarray.tolist(beamsearch_outputs[:, :, i])
                    predict_list = predict_list[0]
                    predict_seq = [vocab.id2word(idx) for idx in predict_list]
                    decoded_words = " ".join(predict_seq).split()
                    #                    decoded_words = decoded_words

                    try:
                        if decoded_words[0] == '[STOPDOC]':
                            decoded_words = decoded_words[1:]
                        # index of the (first) [STOP] symbol
                        fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        decoded_words = decoded_words

                    if decoded_words[-1] not in ('.', '!', '?'):
                        decoded_words.append('.')
                    decoded_words_all = []
                    decoded_output = ' '.join(
                        decoded_words).strip()  # single string
                    decoded_words_all.append(decoded_output)
                    decoded_words_all = ' '.join(decoded_words_all).strip()
                    decoded_words_all = decoded_words_all.replace("[UNK] ", "")
                    decoded_words_all = decoded_words_all.replace("[UNK]", "")
                    decoded_words_all = decoded_words_all.replace(" ", "")
                    decoded_words_all, _ = re.subn(r"(! ){2,}", "",
                                                   decoded_words_all)
                    decoded_words_all, _ = re.subn(r"(\. ){2,}", "",
                                                   decoded_words_all)
                    if decoded_words_all.startswith(','):
                        decoded_words_all = decoded_words_all[1:]
                    print("The resonse   : {}".format(decoded_words_all))
Example 3
def main():
    ################################
    ## Module 1: data preparation
    data_ = data.Data(args.data_dir, args.vocab_size)

    # process the ICD tree
    parient_children, level2_parients, leafNodes, adj, node2id, hier_dicts = utils.build_tree(
        os.path.join(args.data_dir, 'note_labeled.csv'))
    graph = utils.generate_graph(parient_children, node2id)
    args.node2id = node2id
    args.adj = torch.Tensor(adj).long().to(args.device)
    args.leafNodes = leafNodes
    args.hier_dicts = hier_dicts

    # TODO: details of the batcher object
    g_batcher = GenBatcher(data_, args)

    #################################
    ## Module 2: build the generator (G) model and pre-train it
    # TODO: details of the Generator object
    gen_model = Generator(args, data_, graph, level2_parients)

    gen_model.to(args.device)
    # TODO: details of the generated object
    generated = Generated_example(gen_model, data_, g_batcher)
    # pre-train the G model
    pre_train_generator(gen_model, g_batcher, 10)

    # use G to generate some negative samples
    generated.generator_train_negative_samples()
    generated.generator_test_negative_samples()

    #####################################
    ## Module 3: build the discriminator (D) model and pre-train it
    d_model = Discriminator(args, data_)

    d_batcher = DisBatcher(data_, args)

    # pre-train the D model
    pre_train_discriminator(d_model, d_batcher, 25)

    ########################################
    ## Module 4: alternately train the G and D models
    for epoch in range(args.num_epochs):
        batches = g_batcher.get_batches(mode='train')
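        # consume the generator batches in chunks of 1000, alternating G and D updates on
        # each chunk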
        for step in range(int(len(batches) / 1000)):

            # train the G model
            train_generator(gen_model, d_model, g_batcher, d_batcher,
                            batches[step * 1000:(step + 1) * 1000], generated)

            # generate negative samples for training D
            generated.generator_samples(
                "train_sample_generated/" + str(epoch) + "epoch_step" +
                str(step) + "_temp_positive", "train_sample_generated/" +
                str(epoch) + "epoch_step" + str(step) + "_temp_negative", 1000)

            # generate test samples
            generated.generator_test_samples()

            # TODO: evaluate the performance of the G model

            # create the training batches for D (containing both negative and positive samples)
            d_batcher.train_batch = d_batcher.create_batches(mode='train',
                                                             shuffleis=True)

            # train the D network
            train_discriminator(d_model, 5, d_batcher,
                                d_batcher.get_batches(mode="train"))
Example 4
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(
        tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception(
                "Logdir %s doesn't exist. Run in train mode to create it." %
                (FLAGS.log_root))

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    tf.set_random_seed(111)  # a seed value for randomness
    # train-classification  train-sentiment  train-cnn-classification  train-generator

    if FLAGS.mode == "train-classifier":

        #print("Start pre-training......")
        model_class = Classification(hps_discriminator, vocab)
        cla_batcher = ClaBatcher(hps_discriminator, vocab)
        sess_cls, saver_cls, train_dir_cls = setup_training_classification(
            model_class)
        print("Start pre-training classification......")
        run_pre_train_classification(model_class, cla_batcher, 1, sess_cls,
                                     saver_cls, train_dir_cls)  #10
        generated = Generate_training_sample(model_class, vocab, cla_batcher,
                                             sess_cls)

        print("Generating training examples......")
        generated.generate_training_example("train")
        generated.generate_test_example("test")

    elif FLAGS.mode == "train-sentimentor":

        model_class = Classification(hps_discriminator, vocab)
        cla_batcher = ClaBatcher(hps_discriminator, vocab)
        sess_cls, saver_cls, train_dir_cls = setup_training_classification(
            model_class)

        print("Start pre_train_sentimentor......")
        model_sentiment = Sentimentor(hps_generator, vocab)
        sentiment_batcher = SenBatcher(hps_generator, vocab)
        sess_sen, saver_sen, train_dir_sen = setup_training_sentimentor(
            model_sentiment)
        util.load_ckpt(saver_cls, sess_cls, ckpt_dir="train-classification")
        run_pre_train_sentimentor(model_sentiment, sentiment_batcher, 1,
                                  sess_sen, saver_sen, train_dir_sen)  #1

    elif FLAGS.mode == "test":

        config = {
            'n_epochs': 5,
            'kernel_sizes': [3, 4, 5],
            'dropout_rate': 0.5,
            'val_split': 0.4,
            'edim': 300,
            'n_words': None,  # Leave as none
            'std_dev': 0.05,
            'sentence_len': 50,
            'n_filters': 100,
            'batch_size': 50
        }
        config['n_words'] = 50000

        cla_cnn_batcher = CNN_ClaBatcher(hps_discriminator, vocab)
        cnn_classifier = CNN(config)
        sess_cnn_cls, saver_cnn_cls, train_dir_cnn_cls = setup_training_cnnclassifier(
            cnn_classifier)
        #util.load_ckpt(saver_cnn_cls, sess_cnn_cls, ckpt_dir="train-cnnclassification")
        run_train_cnn_classifier(cnn_classifier, cla_cnn_batcher, 1,
                                 sess_cnn_cls, saver_cnn_cls,
                                 train_dir_cnn_cls)  #1

        files = os.listdir("test-generate-transfer/")
        for file_ in files:
            run_test_our_method(cla_cnn_batcher, cnn_classifier, sess_cnn_cls,
                                "test-generate-transfer/" + file_ + "/*")

    #elif FLAGS.mode == "test":

    elif FLAGS.mode == "train-generator":

        model_class = Classification(hps_discriminator, vocab)
        cla_batcher = ClaBatcher(hps_discriminator, vocab)
        sess_cls, saver_cls, train_dir_cls = setup_training_classification(
            model_class)

        model_sentiment = Sentimentor(hps_generator, vocab)
        sentiment_batcher = SenBatcher(hps_generator, vocab)
        sess_sen, saver_sen, train_dir_sen = setup_training_sentimentor(
            model_sentiment)

        config = {
            'n_epochs': 5,
            'kernel_sizes': [3, 4, 5],
            'dropout_rate': 0.5,
            'val_split': 0.4,
            'edim': 300,
            'n_words': None,  # Leave as none
            'std_dev': 0.05,
            'sentence_len': 50,
            'n_filters': 100,
            'batch_size': 50
        }
        config['n_words'] = 50000

        cla_cnn_batcher = CNN_ClaBatcher(hps_discriminator, vocab)
        cnn_classifier = CNN(config)
        sess_cnn_cls, saver_cnn_cls, train_dir_cnn_cls = setup_training_cnnclassifier(
            cnn_classifier)

        model = Generator(hps_generator, vocab)
        batcher = GenBatcher(vocab, hps_generator)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)

        util.load_ckpt(saver_cnn_cls,
                       sess_cnn_cls,
                       ckpt_dir="train-cnnclassification")
        util.load_ckpt(saver_sen, sess_sen, ckpt_dir="train-sentimentor")

        generated = Generated_sample(model, vocab, batcher, sess_ge)
        print("Start pre-training generator......")
        run_pre_train_generator(model, batcher, 1, sess_ge, saver_ge,
                                train_dir_ge, generated, cla_cnn_batcher,
                                cnn_classifier, sess_cnn_cls)  # 4

        generated.generate_test_negetive_example(
            "temp_negetive",
            batcher)  # batcher, model_class, sess_cls, cla_batcher
        generated.generate_test_positive_example("temp_positive", batcher)

        #run_test_our_method(cla_cnn_batcher, cnn_classifier, sess_cnn_cls,
        #                    "temp_negetive" + "/*")

        loss_window = 0
        t0 = time.time()
        print("begin reinforcement learning:")
        for epoch in range(30):
            batches = batcher.get_batches(mode='train')
            for i in range(len(batches)):
                current_batch = copy.deepcopy(batches[i])
                sentiment_batch = batch_sentiment_batch(
                    current_batch, sentiment_batcher)
                result = model_sentiment.max_generator(sess_sen,
                                                       sentiment_batch)
                weight = result['generated']
                current_batch.weight = weight
                sentiment_batch.weight = weight

                cla_batch = batch_classification_batch(current_batch, batcher,
                                                       cla_batcher)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)

                cc = SmoothingFunction()

                reward_sentiment = 1 - np.abs(0.5 - result['y_pred_auc'])
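                # this reward is highest when the classifier's prediction is close to 0.5,
                # i.e. when the classifier is maximally uncertain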
                reward_BLEU = []
                for k in range(FLAGS.batch_size):
                    reward_BLEU.append(
                        sentence_bleu(
                            [current_batch.original_reviews[k].split()],
                            cla_batch.original_reviews[k].split(),
                            smoothing_function=cc.method1))

                reward_BLEU = np.array(reward_BLEU)

                reward_de = (2 / (1.0 / (1e-6 + reward_sentiment) + 1.0 /
                                  (1e-6 + reward_BLEU)))
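                # reward_de is the (epsilon-smoothed) harmonic mean of the sentiment and
                # BLEU rewards: 2 / (1/a + 1/b); e.g. a=0.8, b=0.4 gives roughly 0.53, so
                # both terms must be high for the combined reward to be high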

                result = model.run_train_step(sess_ge, current_batch)
                train_step = result[
                    'global_step']  # we need this to update our running average loss
                loss = result['loss']
                loss_window += loss
                if train_step % 100 == 0:
                    t1 = time.time()
                    tf.logging.info(
                        'seconds for %d training generator step: %.3f ',
                        train_step, (t1 - t0) / 100)
                    t0 = time.time()
                    tf.logging.info('loss: %f', loss_window /
                                    100)  # print the loss to screen
                    loss_window = 0.0
                if train_step % 10000 == 0:

                    generated.generate_test_negetive_example(
                        "test-generate-transfer/" + str(epoch) + "epoch_step" +
                        str(train_step) + "_temp_positive", batcher)
                    generated.generate_test_positive_example(
                        "test-generate/" + str(epoch) + "epoch_step" +
                        str(train_step) + "_temp_positive", batcher)
                    #saver_ge.save(sess, train_dir + "/model", global_step=train_step)
                    #run_test_our_method(cla_cnn_batcher, cnn_classifier, sess_cnn_cls,
                    #                    "test-generate-transfer/" + str(epoch) + "epoch_step" + str(
                    #                        train_step) + "_temp_positive" + "/*")

                cla_batch, bleu = output_to_classification_batch(
                    result['generated'], current_batch, batcher, cla_batcher,
                    cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_sentiment = result['y_pred_auc']
                reward_result_bleu = np.array(bleu)

                reward_result = (2 / (1.0 /
                                      (1e-6 + reward_result_sentiment) + 1.0 /
                                      (1e-6 + reward_result_bleu)))

                current_batch.score = 1 - current_batch.score
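                # flip the target sentiment label (presumably 0 <-> 1) so the generator is
                # asked to produce the style-transferred output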

                result = model.max_generator(sess_ge, current_batch)

                cla_batch, bleu = output_to_classification_batch(
                    result['generated'], current_batch, batcher, cla_batcher,
                    cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_transfer_sentiment = result['y_pred_auc']
                reward_result_transfer_bleu = np.array(bleu)

                reward_result_transfer = (
                    2 / (1.0 /
                         (1e-6 + reward_result_transfer_sentiment) + 1.0 /
                         (1e-6 + reward_result_transfer_bleu)))

                #tf.logging.info("reward_nonsentiment: "+str(reward_sentiment) +" output_original_sentiment: "+str(reward_result_sentiment)+" output_original_bleu: "+str(reward_result_bleu))

                reward = reward_result_transfer  #reward_de + reward_result_sentiment +
                #tf.logging.info("reward_de: "+str(reward_de))

                model_sentiment.run_train_step(sess_sen, sentiment_batch,
                                               reward)
Example 5
def main(unused_argv):
  if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
    raise Exception("Problem with flags: %s" % unused_argv)

  tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
  tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

  # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
  FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
  if not os.path.exists(FLAGS.log_root):
    if FLAGS.mode=="train":
      os.makedirs(FLAGS.log_root)
    else:
      raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

  vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary


  # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
  hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num','max_dec_steps', 'max_enc_steps']
  hps_dict = {}
  for key,val in FLAGS.__flags.items(): # for each flag
    if key in hparam_list: # if it's in the list
      hps_dict[key] = val # add it to the dict
  hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

  hparam_list = ['lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm',
                 'hidden_dim', 'emb_dim', 'batch_size', 'max_enc_sen_num', 'max_enc_seq_len']
  hps_dict = {}
  for key, val in FLAGS.__flags.items():  # for each flag
      if key in hparam_list:  # if it's in the list
          hps_dict[key] = val  # add it to the dict
  hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

  # Create a batcher object that will create minibatches of data
  batcher = GenBatcher(vocab, hps_generator)




  tf.set_random_seed(111) # a seed value for randomness





  if hps_generator.mode == 'train':
    print("Start pre-training......")
    model = Generator(hps_generator, vocab)

    sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
    generated = Generated_sample(model, vocab, batcher, sess_ge)
    print("Start pre-training generator......")
    run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge, train_dir_ge,generated) # this is an infinite loop until 

    print("Generating negetive examples......")
    generated.generator_whole_negetive_example()
    generated.generator_test_negetive_example()

    model_dis = Discriminator(hps_discriminator, vocab)
    dis_batcher = DisBatcher(hps_discriminator, vocab, "train/generated_samples_positive/*", "train/generated_samples_negetive/*", "test/generated_samples_positive/*", "test/generated_samples_negetive/*")
    sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis)
    print("Start pre-training discriminator......")
    #run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test")
    run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis, saver_dis, train_dir_dis)

    util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
    
    generated.generator_sample_example("sample_temp_positive", "sample_temp_negetive", 1000)

    generated.generator_test_sample_example("test_sample_temp_positive",
                                       "test_sample_temp_negetive",
                                       200)
    generated.generator_test_max_example("test_max_temp_positive",
                                       "test_max_temp_negetive",
                                       200)
    tf.logging.info("true data diversity: ")
    eva = Evaluate()
    eva.diversity_evaluate("test_sample_temp_positive" + "/*")



    print("Start adversial training......")
    whole_decay = False
    for epoch in range(1):
        batches = batcher.get_batches(mode='train')
        for step in range(int(len(batches)/1000)):

            run_train_generator(model,model_dis,sess_dis,batcher,dis_batcher,batches[step*1000:(step+1)*1000],sess_ge, saver_ge, train_dir_ge,generated) #(model, discirminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated):
            generated.generator_sample_example("sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 1000)
            #generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200)

            tf.logging.info("test performance: ")
            tf.logging.info("epoch: "+str(epoch)+" step: "+str(step))
            generated.generator_test_sample_example(
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive", 200)
            generated.generator_test_max_example("test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                                            "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive",
                                            200)

            dis_batcher.train_queue = []
            dis_batcher.train_queue = []
            for i in range(epoch+1):
              for j in range(step+1):
                dis_batcher.train_queue += dis_batcher.fill_example_queue("sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_positive/*")
                dis_batcher.train_queue += dis_batcher.fill_example_queue("sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_negetive/*")
            dis_batcher.train_batch = dis_batcher.create_batches(mode="train", shuffleis=True)

            #dis_batcher.valid_batch = dis_batcher.train_batch
            whole_decay = run_train_discriminator(model_dis, 5, dis_batcher, dis_batcher.get_batches(mode="train"),
                                                  sess_dis, saver_dis, train_dir_dis, whole_decay)

  '''elif hps_generator.mode == 'decode':
    decode_model_hps = hps_generator  # This will be the hyperparameters for the decoder model
    model = Generator(decode_model_hps, vocab)
    generated = Generated_sample(model, vocab, batcher)
    bleu_score = generated.compute_BLEU()
    tf.logging.info('bleu: %f', bleu_score)  # print the BLEU score to screen'''

  else:
    raise ValueError("The 'mode' flag must be one of train/decode")  # assumed; the original snippet is truncated here
Example 6
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(
        tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if "train" in FLAGS.mode:
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception(
                "Logdir %s doesn't exist. Run in train mode to create it." %
                (FLAGS.log_root))

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # print('FLAGS.flag_values_dict() ->', FLAGS.flag_values_dict())
    flags_dict = FLAGS.flag_values_dict()
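    # flag_values_dict() returns plain values, so no .value access is needed when
    # filtering the flags below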

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_sen_num', 'max_dec_steps', 'max_enc_steps'
    ]

    hps_dict = {}
    for key, val in flags_dict.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    print('hps_dict ->', json.dumps(hps_dict, ensure_ascii=False))
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
        'max_enc_sen_num', 'max_enc_seq_len'
    ]
    hps_dict = {}
    for key, val in flags_dict.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # # test
    # model_dis = Discriminator(hps_discriminator, vocab)
    # model_dis.build_graph()
    # sys.exit(0)
    # # test

    print('before load batcher...')
    # Create a batcher object that will create minibatches of data
    batcher = GenBatcher(vocab, hps_generator)
    print('after load batcher...')

    tf.set_random_seed(111)  # a seed value for randomness

    if hps_generator.mode == 'adversarial_train':
        print("Start pre-training......")
        model = Generator(hps_generator, vocab)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
        generated = Generated_sample(model, vocab, batcher, sess_ge)

        model_dis = Discriminator(hps_discriminator, vocab)
        dis_batcher = DisBatcher(hps_discriminator, vocab,
                                 "discriminator_train/positive/*",
                                 "discriminator_train/negative/*",
                                 "discriminator_test/positive/*",
                                 "discriminator_test/negative/*")
        sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
            model_dis)

        util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator")
        util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        if not os.path.exists("MLE"): os.mkdir("MLE")

        print("evaluate the diversity of MLE (decode based on sampling)")
        generated.generator_test_sample_example("MLE/" + "MLE_sample_positive",
                                                "MLE/" + "MLE_sample_negative",
                                                200)

        print(
            "evaluate the diversity of MLE (decode based on max probability)")
        generated.generator_test_max_example("MLE/" + "MLE_max_temp_positive",
                                             "MLE/" + "MLE_max_temp_negative",
                                             200)

        print("Start adversarial  training......")
        if not os.path.exists("train_sample_generated"):
            os.mkdir("train_sample_generated")
        if not os.path.exists("test_max_generated"):
            os.mkdir("test_max_generated")
        if not os.path.exists("test_sample_generated"):
            os.mkdir("test_sample_generated")

        whole_decay = False
        for epoch in range(10):
            batches = batcher.get_batches(mode='train')
            for step in range(int(len(batches) / 1000)):

                run_train_generator(
                    model, model_dis, sess_dis, batcher, dis_batcher,
                    batches[step * 1000:(step + 1) * 1000], sess_ge, saver_ge,
                    train_dir_ge, generated
                )  # (model, discirminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated):
                generated.generator_sample_example(
                    "train_sample_generated/" + str(epoch) + "epoch_step" +
                    str(step) + "_temp_positive", "train_sample_generated/" +
                    str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                    1000)
                # generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200)

                tf.logging.info("test performance: ")
                tf.logging.info("epoch: " + str(epoch) + " step: " + str(step))
                print(
                    "evaluate the diversity of DP-GAN (decode based on sampling)"
                )
                generated.generator_test_sample_example(
                    "test_sample_generated/" + str(epoch) + "epoch_step" +
                    str(step) + "_temp_positive", "test_sample_generated/" +
                    str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                    200)
                print(
                    "evaluate the diversity of DP-GAN (decode based on max probability)"
                )
                generated.generator_test_max_example(
                    "test_max_generated/" + str(epoch) + "epoch_step" +
                    str(step) + "_temp_positive", "test_max_generated/" +
                    str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                    200)

                dis_batcher.train_queue = []
                dis_batcher.train_queue = []
                for i in range(epoch + 1):
                    for j in range(step + 1):
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_positive/*")
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_negative/*")
                dis_batcher.train_batch = dis_batcher.create_batches(
                    mode="train", shuffleis=True)

                # dis_batcher.valid_batch = dis_batcher.train_batch
                whole_decay = run_train_discriminator(
                    model_dis, 5, dis_batcher,
                    dis_batcher.get_batches(mode="train"), sess_dis, saver_dis,
                    train_dir_dis, whole_decay)

    elif hps_generator.mode == 'train_generator':
        print("Start pre-training......")
        model = Generator(hps_generator, vocab)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
        generated = Generated_sample(model, vocab, batcher, sess_ge)
        print("Start pre-training generator......")
        # this is an infinite loop until
        run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge,
                                train_dir_ge, generated)

        print("Generating negative examples......")
        generated.generator_train_negative_example()
        generated.generator_test_negative_example()
    elif hps_generator.mode == 'train_discriminator':
        print("Start pre-training......")
        model = Generator(hps_generator, vocab)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)

        # util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")

        model_dis = Discriminator(hps_discriminator, vocab)
        dis_batcher = DisBatcher(hps_discriminator, vocab,
                                 "discriminator_train/positive/*",
                                 "discriminator_train/negative/*",
                                 "discriminator_test/positive/*",
                                 "discriminator_test/negative/*")
        sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
            model_dis)
        print("Start pre-training discriminator......")
        # run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test")
        if not os.path.exists("discriminator_result"):
            os.mkdir("discriminator_result")
        run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis,
                                    saver_dis, train_dir_dis)
Example 7
def main(unused_argv):
  if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
    raise Exception("Problem with flags: %s" % unused_argv)

  tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
  tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

  # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
  FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
  if not os.path.exists(FLAGS.log_root):
    if "train" in FLAGS.mode:
      os.makedirs(FLAGS.log_root)
    else:
      raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

  vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary


  # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
  hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num','max_dec_steps', 'max_enc_steps']
  hps_dict = {}
  for key,val in FLAGS.__flags.items(): # for each flag
    if key in hparam_list: # if it's in the list
      hps_dict[key] = val # add it to the dict
  hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

  hparam_list = ['lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm',
                 'hidden_dim', 'emb_dim', 'batch_size', 'max_enc_sen_num', 'max_enc_seq_len']
  hps_dict = {}
  for key, val in FLAGS.__flags.items():  # for each flag
      if key in hparam_list:  # if it's in the list
          hps_dict[key] = val  # add it to the dict
  hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

  # Create a batcher object that will create minibatches of data
  batcher = GenBatcher(vocab, hps_generator)




  tf.set_random_seed(111) # a seed value for randomness





  if hps_generator.mode == 'adversarial_train':
    print("Start pre-training......")
    model = Generator(hps_generator, vocab)

    sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
    generated = Generated_sample(model, vocab, batcher, sess_ge)


    model_dis = Discriminator(hps_discriminator, vocab)
    dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*")
    sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis)
    
    
    util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator")

    util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
    
    
   
    if not os.path.exists("MLE"): os.mkdir("MLE")

    print("evaluate the diversity of MLE (decode based on sampling)")
    generated.generator_test_sample_example("MLE/"+"MLE_sample_positive",
                                       "MLE/"+"MLE_sample_negative",
                                       200)
                                       
    print("evaluate the diversity of MLE (decode based on max probability)")
    generated.generator_test_max_example("MLE/"+"MLE_max_temp_positive",
                                       "MLE/"+"MLE_max_temp_negative",
                                       200)
  

    print("Start adversarial  training......")
    if not os.path.exists("train_sample_generated"): os.mkdir("train_sample_generated")
    if not os.path.exists("test_max_generated"): os.mkdir("test_max_generated")
    if not os.path.exists("test_sample_generated"): os.mkdir("test_sample_generated")
    
    
    
    whole_decay = False
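    # Adversarial schedule: for every chunk of 1000 batches, update the generator
    # with discriminator feedback, dump sampled outputs to disk, then retrain the
    # discriminator on all chunks generated so far.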
    for epoch in range(10):
        batches = batcher.get_batches(mode='train')
        for step in range(int(len(batches)/1000)):

            run_train_generator(model, model_dis, sess_dis, batcher, dis_batcher,
                                batches[step*1000:(step+1)*1000], sess_ge, saver_ge,
                                train_dir_ge, generated)  # signature: (model, discriminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated)
            generated.generator_sample_example("train_sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "train_sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negative", 1000)
            #generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200)

            tf.logging.info("test performance: ")
            tf.logging.info("epoch: "+str(epoch)+" step: "+str(step))
            print("evaluate the diversity of DP-GAN (decode based on  max probability)")
            generated.generator_test_sample_example(
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 200)
            print("evaluate the diversity of DP-GAN (decode based on sampling)")
            generated.generator_test_max_example("test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                                            "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                                            200)

            dis_batcher.train_queue = []
            for i in range(epoch+1):
              for j in range(step+1):
                dis_batcher.train_queue += dis_batcher.fill_example_queue("train_sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_positive/*")
                dis_batcher.train_queue += dis_batcher.fill_example_queue("train_sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_negative/*")
            dis_batcher.train_batch = dis_batcher.create_batches(mode="train", shuffleis=True)

            #dis_batcher.valid_batch = dis_batcher.train_batch
            whole_decay = run_train_discriminator(model_dis, 5, dis_batcher, dis_batcher.get_batches(mode="train"),
                                                  sess_dis, saver_dis, train_dir_dis, whole_decay)

  elif hps_generator.mode == 'train_generator':
    print("Start pre-training......")
    model = Generator(hps_generator, vocab)

    sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
    generated = Generated_sample(model, vocab, batcher, sess_ge)
    print("Start pre-training generator......")
    run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge, train_dir_ge, generated)  # this is an infinite loop until interrupted

    print("Generating negative examples......")
    generated.generator_train_negative_example()
    generated.generator_test_negative_example()
  elif hps_generator.mode == 'train_discriminator':
    print("Start pre-training......")
    model = Generator(hps_generator, vocab)

    sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
    
    #util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")

    model_dis = Discriminator(hps_discriminator, vocab)
    dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*")
    sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis)
    print("Start pre-training discriminator......")
    #run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test")
    if not os.path.exists("discriminator_result"): os.mkdir("discriminator_result")
    run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis, saver_dis, train_dir_dis)
Ejemplo n.º 8
0
def main():
    ################################
    ## Module 1: data preparation
    data_ = data.Data(args.data_dir, args.vocab_size)

    # Process the ICD tree
    parient_children, level2_parients, leafNodes, adj, node2id, hier_dicts = utils.build_tree(
        os.path.join(args.data_dir, 'note_labeled_v2.csv'))
    graph = utils.generate_graph(parient_children, node2id)
    args.node2id = node2id
    args.id2node = {id: node for node, id in node2id.items()}
    args.adj = torch.Tensor(adj).long().to(args.device)
    # args.leafNodes=leafNodes
    args.hier_dicts = hier_dicts
    # args.level2_parients=level2_parients
    #print('836:',args.id2node.get(836),args.id2node.get(0))

    # TODO: details of the batcher object
    g_batcher = GenBatcher(data_, args)

    #################################
    ## Module 2: build the G model and pre-train it
    # TODO: details of the Generator object
    gen_model_eval = Generator(args, data_, graph, level2_parients)
    gen_model_target = Generator(args, data_, graph, level2_parients)
    gen_model_target.eval()
    print(gen_model_eval)

    # for name,param in gen_model_eval.named_parameters():
    #     print(name,param.size(),type(param))
    buffer = ReplayBuffer(capacity=100000)
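    # Replay buffer shared by sample generation and the generator/discriminator updates below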
    gen_model_eval.to(args.device)
    gen_model_target.to(args.device)

    # TODO: details of the generated object

    # Pre-train the G model
    #pre_train_generator(gen_model,g_batcher,10)

    #####################################
    ## Module 3: build the D model and pre-train it
    d_model = Discriminator(args)
    d_model.to(args.device)

    # Pre-train the D model
    #pre_train_discriminator(d_model,d_batcher,25)

    ########################################
    ## Module 4: alternately train the G and D models

    # Write the evaluation results to a file
    f = open('valid_result.csv', 'w')
    writer = csv.writer(f)
    writer.writerow([
        'avg_micro_p', 'avg_macro_p', 'avg_micro_r', 'avg_macro_r',
        'avg_micro_f1', 'avg_macro_f1', 'avg_micro_auc_roc',
        'avg_macro_auc_roc'
    ])
    epoch_f = []
    for epoch in range(args.num_epochs):
        batches = g_batcher.get_batches(mode='train')
        print('number of batches:', len(batches))
        for step in range(len(batches)):
            #print('step:',step)
            current_batch = batches[step]
            ehrs = [example.ehr for example in current_batch]
            ehrs = torch.Tensor(ehrs).long().to(args.device)

            hier_labels = [example.hier_labels for example in current_batch]

            true_labels = []

            # Pad each hierarchical label path to a fixed length of 4
            for i in range(len(hier_labels)):  # i indexes the samples
                for j in range(len(hier_labels[i])):  # j indexes each label path of sample i
                    if len(hier_labels[i][j]) < 4:
                        hier_labels[i][j] = hier_labels[i][j] + [0] * (
                            4 - len(hier_labels[i][j]))
                # if len(hier_labels[i]) < args.k:
                #     for time in range(args.k - len(hier_labels[i])):
                #         hier_labels[i].append([0] * args.hops)

            for sample in hier_labels:
                #print('sample:',sample)
                true_labels.append([row[1] for row in sample])

            predHierLabels, batchStates_n, batchHiddens_n = generator.generated_negative_samples(
                gen_model_eval, d_model, ehrs, hier_labels, buffer)

            #true_labels = [example.labels for example in current_batch]

            _, _, avgJaccard = full_eval.process_labels(
                predHierLabels, true_labels, args)

            # G generates the positive samples for training D
            batchStates_p, batchHiddens_p = generator.generated_positive_samples(
                gen_model_eval, ehrs, hier_labels, buffer)

            # Train the D network
            #d_loss=train_discriminator(d_model,batchStates_n,batchHiddens_n,batchStates_p,batchHiddens_p,mode=args.mode)

            # Train the G model
            #for g_epoch in range(10):
            g_loss = train_generator(gen_model_eval,
                                     gen_model_target,
                                     d_model,
                                     batchStates_n,
                                     batchHiddens_n,
                                     buffer,
                                     mode=args.mode)

            print('batch_number:{}, avgJaccard:{:.4f}, g_loss:{:.4f}'.format(
                step, avgJaccard, g_loss))

        # After each epoch, evaluate the G and D models on the validation set
        avg_micro_f1 = evaluate(g_batcher,
                                gen_model_eval,
                                d_model,
                                buffer,
                                writer,
                                flag='valid')
        epoch_f.append(avg_micro_f1)

    # plot results
    window = max(1, int(args.num_epochs / 20))  # guard against a zero-size rolling window
    print('window:', window)
    fig, ((ax1), (ax2)) = plt.subplots(2, 1, sharey=True, figsize=[9, 9])
    rolling_mean = pd.Series(epoch_f).rolling(window).mean()
    std = pd.Series(epoch_f).rolling(window).std()
    ax1.plot(rolling_mean)
    ax1.fill_between(range(len(epoch_f)),
                     rolling_mean - std,
                     rolling_mean + std,
                     color='orange',
                     alpha=0.2)
    ax1.set_title(
        'Moving Average F1 ({}-epoch window)'.format(window))
    ax1.set_xlabel('Epoch Number')
    ax1.set_ylabel('F1')

    ax2.plot(epoch_f)
    ax2.set_title('Performance on valid set')
    ax2.set_xlabel('Epoch Number')
    ax2.set_ylabel('F1')

    fig.tight_layout(pad=2)
    plt.show()
    fig.savefig('results.png')

    f.close()
Ejemplo n.º 9
0
def main(unused_argv):
    if len(unused_argv
           ) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(
        tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception(
                "Logdir %s doesn't exist. Run in train mode to create it." %
                (FLAGS.log_root))

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    tf.set_random_seed(111)  # a seed value for randomness

    if hps_generator.mode == 'train':

        print("Start pre-training......")
        model_class = Classification(hps_discriminator, vocab)
        cla_batcher = ClaBatcher(hps_discriminator, vocab)
        sess_cls, saver_cls, train_dir_cls = setup_training_classification(
            model_class)
        print("Start pre-training classification......")
        #run_pre_train_classification(model_class, cla_batcher, 10, sess_cls, saver_cls, train_dir_cls)
        #generated = Generate_training_sample(model_class, vocab, cla_batcher, sess_cls)

        #print("Generating training examples......")
        #generated.generate_training_example("train")
        #generated.generator_validation_example("valid")

        model_sentiment = Sentimentor(hps_generator, vocab)
        sentiment_batcher = SenBatcher(hps_generator, vocab)
        sess_sen, saver_sen, train_dir_sen = setup_training_sentimentor(
            model_sentiment)
        #run_pre_train_sentimentor(model_sentiment,sentiment_batcher,1,sess_sen,saver_sen,train_dir_sen)
        sentiment_generated = Generate_non_sentiment_weight(
            model_sentiment, vocab, sentiment_batcher, sess_sen)
        #sentiment_generated.generate_training_example("train_sentiment")
        #sentiment_generated.generator_validation_example("valid_sentiment")

        model = Generator(hps_generator, vocab)
        # Create a batcher object that will create minibatches of data
        batcher = GenBatcher(vocab, hps_generator)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)

        util.load_ckpt(saver_sen, sess_sen, ckpt_dir="train-sentimentor")

        util.load_ckpt(saver_cls, sess_cls, ckpt_dir="train-classification")

        generated = Generated_sample(model, vocab, batcher, sess_ge)
        #print("Start pre-training generator......")
        run_pre_train_generator(
            model, batcher, 4, sess_ge, saver_ge, train_dir_ge, generated,
            model_class, sess_cls,
            cla_batcher)  # this is an infinite loop until interrupted

        #generated.generator_validation_negetive_example("temp_negetive", batcher, model_class,sess_cls,cla_batcher) # batcher, model_class, sess_cls, cla_batcher
        #generated.generator_validation_positive_example(
        #    "temp_positive", batcher, model_class,sess_cls,cla_batcher)

        loss_window = 0
        t0 = time.time()
        print("begin dual learning:")
        for epoch in range(30):
            batches = batcher.get_batches(mode='train')
            for i in range(len(batches)):
                current_batch = copy.deepcopy(batches[i])
                sentiment_batch = batch_sentiment_batch(
                    current_batch, sentiment_batcher)
                result = model_sentiment.max_generator(sess_sen,
                                                       sentiment_batch)
                weight = result['generated']
                current_batch.weight = weight
                sentiment_batch.weight = weight

                cla_batch = batch_classification_batch(current_batch, batcher,
                                                       cla_batcher)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)

                cc = SmoothingFunction()

                reward_sentiment = 1 - np.abs(0.5 - result['y_pred_auc'])
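                # (this reward is largest when the classifier output is closest to 0.5)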
                reward_BLEU = []
                for k in range(FLAGS.batch_size):
                    reward_BLEU.append(
                        sentence_bleu(
                            [current_batch.original_reviews[k].split()],
                            cla_batch.original_reviews[k].split(),
                            smoothing_function=cc.method1))

                reward_BLEU = np.array(reward_BLEU)

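                # Harmonic mean of the sentiment and BLEU rewards; the 1e-6 terms avoid division by zero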
                reward_de = (2 / (1.0 / (1e-6 + reward_sentiment) + 1.0 /
                                  (1e-6 + reward_BLEU)))

                result = model.run_train_step(sess_ge, current_batch)
                train_step = result[
                    'global_step']  # we need this to update our running average loss
                loss = result['loss']
                loss_window += loss
                if train_step % 100 == 0:
                    t1 = time.time()
                    tf.logging.info(
                        'seconds for %d training generator step: %.3f ',
                        train_step, (t1 - t0) / 100)
                    t0 = time.time()
                    tf.logging.info('loss: %f', loss_window /
                                    100)  # print the loss to screen
                    loss_window = 0.0
                if train_step % 10000 == 0:
                    #bleu_score = generatored.compute_BLEU(str(train_step))
                    #tf.logging.info('bleu: %f', bleu_score)  # print the loss to screen
                    generated.generator_validation_negetive_example(
                        "valid-generated-transfer/" + str(epoch) +
                        "epoch_step" + str(train_step) + "_temp_positive",
                        batcher, model_class, sess_cls, cla_batcher)
                    generated.generator_validation_positive_example(
                        "valid-generated/" + str(epoch) + "epoch_step" +
                        str(train_step) + "_temp_positive", batcher,
                        model_class, sess_cls, cla_batcher)
                    #saver_ge.save(sess, train_dir + "/model", global_step=train_step)

                cla_batch, bleu = output_to_classification_batch(
                    result['generated'], current_batch, batcher, cla_batcher,
                    cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_sentiment = result['y_pred_auc']
                reward_result_bleu = np.array(bleu)

                reward_result = (2 / (1.0 /
                                      (1e-6 + reward_result_sentiment) + 1.0 /
                                      (1e-6 + reward_result_bleu)))

                current_batch.score = 1 - current_batch.score
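                # Flip the target sentiment label so the next decoding pass produces the transferred style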

                result = model.max_generator(sess_ge, current_batch)

                cla_batch, bleu = output_to_classification_batch(
                    result['generated'], current_batch, batcher, cla_batcher,
                    cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_transfer_sentiment = result['y_pred_auc']
                reward_result_transfer_bleu = np.array(bleu)

                reward_result_transfer = (
                    2 / (1.0 /
                         (1e-6 + reward_result_transfer_sentiment) + 1.0 /
                         (1e-6 + reward_result_transfer_bleu)))

                #tf.logging.info("reward_nonsentiment: "+str(reward_sentiment) +" output_original_sentiment: "+str(reward_result_sentiment)+" output_original_bleu: "+str(reward_result_bleu))

                reward = reward_result_transfer  #reward_de + reward_result_sentiment +
                #tf.logging.info("reward_de: "+str(reward_de))

                model_sentiment.run_train_step(sess_sen, sentiment_batch,
                                               reward)

    elif hps_generator.mode == 'decode':
        decode_model_hps = hps_generator  # This will be the hyperparameters for the decoder model
        #model = Generator(decode_model_hps, vocab)
        #generated = Generated_sample(model, vocab, batcher)
        #bleu_score = generated.compute_BLEU()
        #tf.logging.info('bleu: %f', bleu_score)  # print the loss to screen

    else:
        raise ValueError("The 'mode' flag must be one of train/eval/decode")
Ejemplo n.º 10
0
def main(argv):
    tf.set_random_seed(111)  # a seed value for randomness

    # Create a batcher object that will create minibatches of data
    # TODO change to pass number

    # --------------- building graph ---------------
    hparam_gen = [
        'mode',
        'model_dir',
        'adagrad_init_acc',
        'steps_per_checkpoint',
        'batch_size',
        'beam_size',
        'cov_loss_wt',
        'coverage',
        'emb_dim',
        'rand_unif_init_mag',
        'gen_vocab_file',
        'gen_vocab_size',
        'hidden_dim',
        'gen_lr',
        'gen_max_gradient',
        'max_dec_steps',
        'max_enc_steps',
        'min_dec_steps',
        'trunc_norm_init_std',
        'single_pass',
        'log_root',
        'data_path',
    ]

    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_gen:  # if it's in the list
            hps_dict[key] = val  # add it to the dict

    hps_gen = namedtuple("HParams4Gen", hps_dict.keys())(**hps_dict)

    print("Building vocabulary for generator ...")
    gen_vocab = Vocab(join_path(hps_gen.data_path, hps_gen.gen_vocab_file),
                      hps_gen.gen_vocab_size)

    hparam_dis = [
        'mode',
        'vocab_type',
        'model_dir',
        'dis_vocab_size',
        'steps_per_checkpoint',
        'learning_rate_decay_factor',
        'dis_vocab_file',
        'num_class',
        'layer_size',
        'conv_layers',
        'max_steps',
        'kernel_size',
        'early_stop',
        'pool_size',
        'pool_layers',
        'dis_max_gradient',
        'batch_size',
        'dis_lr',
        'lr_decay_factor',
        'cell_type',
        'max_enc_steps',
        'max_dec_steps',
        'single_pass',
        'data_path',
        'num_models',
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_dis:  # if it's in the list
            hps_dict[key] = val  # add it to the dict

    hps_dis = namedtuple("HParams4Dis", hps_dict.keys())(**hps_dict)
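    # If the generator and discriminator share a vocabulary file, reuse the word-level
    # settings for the discriminator; otherwise it keeps its own (e.g. character)
    # vocabulary and gets roughly twice the sequence length.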
    if hps_gen.gen_vocab_file == hps_dis.dis_vocab_file:
        hps_dis = hps_dis._replace(vocab_type="word")
        hps_dis = hps_dis._replace(layer_size=hps_gen.emb_dim)
        hps_dis = hps_dis._replace(dis_vocab_size=hps_gen.gen_vocab_size)
    else:
        hps_dis = hps_dis._replace(max_enc_steps=hps_dis.max_enc_steps * 2)
        hps_dis = hps_dis._replace(max_dec_steps=hps_dis.max_dec_steps * 2)
    if FLAGS.mode == "train_gan":
        hps_gen = hps_gen._replace(batch_size=hps_gen.batch_size *
                                   hps_dis.num_models)

    if FLAGS.mode != "pretrain_dis":
        with tf.variable_scope("generator"):
            generator = PointerGenerator(hps_gen, gen_vocab)
            print("Building generator graph ...")
            gen_decoder_scope = generator.build_graph()

    if FLAGS.mode != "pretrain_gen":
        print("Building vocabulary for discriminator ...")
        dis_vocab = Vocab(join_path(hps_dis.data_path, hps_dis.dis_vocab_file),
                          hps_dis.dis_vocab_size)
    if FLAGS.mode in ['train_gan', 'pretrain_dis']:
        with tf.variable_scope("discriminator"), tf.device("/gpu:0"):
            discriminator = Seq2ClassModel(hps_dis)
            print("Building discriminator graph ...")
            discriminator.build_graph()

    hparam_gan = [
        'mode',
        'model_dir',
        'gan_iter',
        'gan_gen_iter',
        'gan_dis_iter',
        'gan_lr',
        'rollout_num',
        'sample_num',
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_gan:  # if it's in the list
            hps_dict[key] = val  # add it to the dict

    hps_gan = namedtuple("HParams4GAN", hps_dict.keys())(**hps_dict)
    hps_gan = hps_gan._replace(mode="train_gan")
    if FLAGS.mode == 'train_gan':
        with tf.device("/gpu:0"):
            print("Creating rollout...")
            rollout = Rollout(generator, 0.8, gen_decoder_scope)

    # --------------- initializing variables ---------------
    all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) + \
        tf.get_collection_ref(tf.GraphKeys.WEIGHTS) + \
        tf.get_collection_ref(tf.GraphKeys.BIASES)
    sess = tf.Session(config=utils.get_config())
    sess.run(tf.variables_initializer(all_variables))
    if FLAGS.mode == "pretrain_gen":
        val_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.val_dir))
        model_dir = ensure_exists(join_path(FLAGS.model_dir, 'generator'))
        print("Restoring the generator model from the latest checkpoint...")
        gen_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[
                v for v in all_variables
                if "generator" in v.name and "GAN" not in v.name
            ])
        gen_dir = ensure_exists(join_path(FLAGS.model_dir, "generator"))
        # gen_dir = ensure_exists(FLAGS.model_dir)
        # temp_saver = tf.train.Saver(
        #     var_list=[v for v in all_variables if "generator" in v.name and "Adagrad" not in v.name])
        ckpt_path = utils.load_ckpt(gen_saver, sess, gen_dir)
        print('going to restore embeddings from checkpoint')
        if not ckpt_path:
            emb_path = join_path(FLAGS.model_dir, "generator", "init_embed")
            if emb_path:
                generator.saver.restore(
                    sess,
                    tf.train.get_checkpoint_state(
                        emb_path).model_checkpoint_path)
                print(
                    colored(
                        "successfully restored embeddings from %s" % emb_path,
                        'green'))
            else:
                print(
                    colored("failed to restore embeddings from %s" % emb_path,
                            'red'))

    elif FLAGS.mode in ["decode", "train_gan"]:
        print("Restoring the generator model from the best checkpoint...")
        dec_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        gan_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.gan_dir))
        gan_val_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.gan_dir,
                      FLAGS.val_dir))
        gan_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        gan_val_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        utils.load_ckpt(dec_saver, sess, val_dir,
                        (FLAGS.mode in ["train_gan", "decode"]))

    if FLAGS.mode in ["pretrain_dis", "train_gan"]:
        dis_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "discriminator" in v.name])
        dis_dir = ensure_exists(join_path(FLAGS.model_dir, 'discriminator'))
        ckpt = utils.load_ckpt(dis_saver, sess, dis_dir)
        if not ckpt:
            if hps_dis.vocab_type == "word":
                discriminator.init_emb(
                    sess, join_path(FLAGS.model_dir, "generator",
                                    "init_embed"))
            else:
                discriminator.init_emb(
                    sess,
                    join_path(FLAGS.model_dir, "discriminator", "init_embed"))

    # --------------- train models ---------------
    if FLAGS.mode != "pretrain_dis":
        gen_batcher_train = GenBatcher("train",
                                       gen_vocab,
                                       hps_gen,
                                       single_pass=hps_gen.single_pass)
        decoder = Decoder(sess, generator, gen_vocab)
        gen_batcher_val = GenBatcher("val",
                                     gen_vocab,
                                     hps_gen,
                                     single_pass=True)
        val_saver = tf.train.Saver(
            max_to_keep=10,
            var_list=[
                v for v in all_variables
                if "generator" in v.name and "GAN" not in v.name
            ])

    if FLAGS.mode != "pretrain_gen":
        dis_val_batch_size = hps_dis.batch_size * hps_dis.num_models \
            if hps_dis.mode == "train_gan" else hps_dis.batch_size * hps_dis.num_models * 2
        dis_batcher_val = DisBatcher(
            hps_dis.data_path,
            "eval",
            gen_vocab,
            dis_vocab,
            dis_val_batch_size,
            single_pass=True,
            max_art_steps=hps_dis.max_enc_steps,
            max_abs_steps=hps_dis.max_dec_steps,
        )

    if FLAGS.mode == "pretrain_gen":
        # resume generator pre-training from the restored checkpoint
        print('Going to pretrain the generator')
        try:
            pretrain_generator(generator, gen_batcher_train, sess,
                               gen_batcher_val, gen_saver, model_dir,
                               val_saver, val_dir)
        except KeyboardInterrupt:
            tf.logging.info("Caught keyboard interrupt on worker....")

    elif FLAGS.mode == "pretrain_dis":
        print('Going to pretrain the discriminator')
        dis_batcher = DisBatcher(
            hps_dis.data_path,
            "decode",
            gen_vocab,
            dis_vocab,
            hps_dis.batch_size * hps_dis.num_models,
            single_pass=hps_dis.single_pass,
            max_art_steps=hps_dis.max_enc_steps,
            max_abs_steps=hps_dis.max_dec_steps,
        )
        try:
            pretrain_discriminator(sess, discriminator, dis_batcher_val,
                                   dis_vocab, dis_batcher, dis_saver)
        except KeyboardInterrupt:
            tf.logging.info("Caught keyboard interrupt on worker....")

    elif FLAGS.mode == "train_gan":
        gen_best_loss = get_best_loss_from_chpt(val_dir)
        gen_global_step = 0
        print('Going to tune the two using Gan')
        for i_gan in range(hps_gan.gan_iter):
            # Train the generator for one step
            g_losses = []
            current_speed = []
            for it in range(hps_gan.gan_gen_iter):
                start_time = time.time()
                batch = gen_batcher_train.next_batch()

                # generate samples
                enc_states, dec_in_state, n_samples, n_targets_padding_mask = decoder.mc_generate(
                    batch, include_start_token=True, s_num=hps_gan.sample_num)
                # get rewards for the samples
                n_rewards = rollout.get_reward(sess, gen_vocab, dis_vocab,
                                               batch, enc_states, dec_in_state,
                                               n_samples, hps_gan.rollout_num,
                                               discriminator)

                # fine tune the generator
                n_sample_targets = [samples[:, 1:] for samples in n_samples]
                n_targets_padding_mask = [
                    padding_mask[:, 1:]
                    for padding_mask in n_targets_padding_mask
                ]
                n_samples = [samples[:, :-1] for samples in n_samples]
                # sample_target_padding_mask = pad_sample(sample_target, gen_vocab, hps_gen)
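                # Map extended-vocabulary (copied OOV) ids back to UNK so the samples
                # index the generator's fixed vocabulary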
                n_samples = [
                    np.where(
                        np.less(samples, hps_gen.gen_vocab_size), samples,
                        np.array([[gen_vocab.word2id(data.UNKNOWN_TOKEN)] *
                                  hps_gen.max_dec_steps] * hps_gen.batch_size))
                    for samples in n_samples
                ]
                results = generator.run_gan_batch(sess, batch, n_samples,
                                                  n_sample_targets,
                                                  n_targets_padding_mask,
                                                  n_rewards)

                gen_global_step = results["global_step"]

                # for visualization
                g_loss = results["loss"]
                if not math.isnan(g_loss):
                    g_losses.append(g_loss)
                else:
                    print(colored('a nan in gan loss', 'red'))
                current_speed.append(time.time() - start_time)

            # Test
            # if FLAGS.gan_gen_iter and (i_gan % 100 == 0 or i_gan == hps_gan.gan_iter - 1):
            if i_gan % 100 == 0 or i_gan == hps_gan.gan_iter - 1:
                print('Going to test the generator.')
                current_speed = sum(current_speed) / (len(current_speed) *
                                                      hps_gen.batch_size)
                average_g_loss = sum(g_losses) / len(g_losses)
                # one more process should be opened for the evaluation
                eval_loss, gen_best_loss = save_ckpt(
                    sess, generator, gen_best_loss, gan_dir, gan_saver,
                    gen_batcher_val, gan_val_dir, gan_val_saver,
                    gen_global_step)

                if eval_loss:
                    print("\nDashboard for " +
                          colored("GAN Generator", 'green') + " updated %s, "
                          "finished steps:\t%s\n"
                          "\tBatch size:\t%s\n"
                          "\tVocabulary size:\t%s\n"
                          "\tCurrent speed:\t%.4f seconds/article\n"
                          "\tAverage training loss:\t%.4f; "
                          "eval loss:\t%.4f" % (
                              datetime.datetime.now().strftime(
                                  "on %m-%d at %H:%M"),
                              gen_global_step,
                              FLAGS.batch_size,
                              hps_gen.gen_vocab_size,
                              current_speed,
                              average_g_loss.item(),
                              eval_loss.item(),
                          ))

            # Train the discriminator
            print('Going to train the discriminator.')
            dis_best_loss = 1000
            dis_losses = []
            dis_accuracies = []
            for d_gan in range(hps_gan.gan_dis_iter):
                batch = gen_batcher_train.next_batch()
                enc_states, dec_in_state, k_samples_words, _ = decoder.mc_generate(
                    batch, s_num=hps_gan.sample_num)
                # should first translate ids to words to avoid UNKs
                articles_oovs = batch.art_oovs
                for samples_words in k_samples_words:
                    dec_batch_words = batch.target_batch
                    conditions_words = batch.enc_batch_extend_vocab
                    if hps_dis.vocab_type == "char":
                        samples = gen_vocab2dis_vocab(samples_words, gen_vocab,
                                                      articles_oovs, dis_vocab,
                                                      hps_dis.max_dec_steps,
                                                      STOP_DECODING)
                        dec_batch = gen_vocab2dis_vocab(
                            dec_batch_words, gen_vocab, articles_oovs,
                            dis_vocab, hps_dis.max_dec_steps, STOP_DECODING)
                        conditions = gen_vocab2dis_vocab(
                            conditions_words, gen_vocab, articles_oovs,
                            dis_vocab, hps_dis.max_enc_steps, PAD_TOKEN)
                    else:
                        samples = samples_words
                        dec_batch = dec_batch_words
                        conditions = conditions_words
                        # UNK tokens in the target are left as-is

                    inputs = np.concatenate([samples, dec_batch], 0)
                    conditions = np.concatenate([conditions, conditions], 0)

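                    # Generated samples are labeled [1, 0] and ground-truth sequences [0, 1]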
                    targets = [[1, 0] for _ in samples] + [[0, 1]
                                                           for _ in dec_batch]
                    targets = np.array(targets)
                    # randomize the samples
                    assert len(inputs) == len(conditions) == len(
                        targets
                    ), "lengths of the inputs, conditions and targets should be the same."
                    indices = np.random.permutation(len(inputs))
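                    # Shuffle, then split into two half-batches so the discriminator
                    # sees a mix of real and generated sequences in each update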
                    inputs = np.split(inputs[indices], 2)
                    conditions = np.split(conditions[indices], 2)
                    targets = np.split(targets[indices], 2)
                    assert len(inputs) % 2 == 0, "the length should be even"

                    results = discriminator.run_one_batch(
                        sess, inputs[0], conditions[0], targets[0])
                    dis_accuracies.append(results["accuracy"].item())
                    dis_losses.append(results["loss"].item())

                    results = discriminator.run_one_batch(
                        sess, inputs[1], conditions[1], targets[1])
                    dis_accuracies.append(results["accuracy"].item())

                ave_dis_acc = sum(dis_accuracies) / len(dis_accuracies)
                if d_gan == hps_gan.gan_dis_iter - 1:
                    if (sum(dis_losses) / len(dis_losses)) < dis_best_loss:
                        dis_best_loss = sum(dis_losses) / len(dis_losses)
                        checkpoint_path = ensure_exists(
                            join_path(hps_dis.model_dir,
                                      "discriminator")) + "/model.ckpt"
                        dis_saver.save(sess,
                                       checkpoint_path,
                                       global_step=results["global_step"])
                    print_dashboard("GAN Discriminator",
                                    results["global_step"].item(),
                                    hps_dis.batch_size, hps_dis.dis_vocab_size,
                                    results["loss"].item(), 0.00, 0.00, 0.00)
                    print("Average training accuracy: \t%.4f" % ave_dis_acc)

                if ave_dis_acc > 0.9:
                    break

    # --------------- decoding samples ---------------
    elif FLAGS.mode == "decode":
        print('Going to decode from the generator.')
        decoder.bs_decode(gen_batcher_train)
        print("Finished decoding..")
        # decode for generating corpus for discriminator

    sess.close()
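
All of these examples build their frozen hyperparameter objects with the same FLAGS-filtering idiom. Below is a minimal, self-contained sketch of that pattern; the build_hparams helper and the literal flag values are illustrative only and do not come from any of the repositories above.

from collections import namedtuple

def build_hparams(parsed_flags, wanted_keys):
    """Keep only the flags named in wanted_keys and freeze them in a namedtuple."""
    hps_dict = {}
    for key, val in parsed_flags.items():
        if key in wanted_keys:
            # Unwrap Flag objects if present, otherwise keep the plain value.
            hps_dict[key] = getattr(val, 'value', val)
    return namedtuple("HParams", hps_dict.keys())(**hps_dict)

# Illustrative usage with plain values standing in for parsed flags:
hps = build_hparams({'lr': 0.15, 'batch_size': 16, 'coverage': False},
                    ['lr', 'batch_size'])
print(hps.lr, hps.batch_size)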