# (model classes, data loaders, and the upper-case config constants used
#  below are assumed to come from the surrounding project)
import random
import time

import numpy as np
import tensorflow as tf


def main():
    # load rhyme table
    table = np.load("./data/table.npy")
    np.random.seed(SEED)
    random.seed(SEED)

    # data loader
    # gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    input_data_loader = Input_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    D = Discriminator(SEQ_LENGTH, num_class, vocab_size, dis_emb_size, dis_filter_sizes, dis_num_filters, 0.2)
    G = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, table, has_input=True)

    # avoid occupying all of the GPU memory
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # savers for different models
    saver_gen = tf.train.Saver()
    saver_dis = tf.train.Saver()
    saver_seqgan = tf.train.Saver()
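    # note: with no var_list argument, each Saver above checkpoints every
    # variable in the graph; they differ only in the checkpoint paths used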

    # gen_data_loader.create_batches(positive_file)
    input_data_loader.create_batches(x_file, y_file)
    log = open('./experiment-log.txt', 'w')
    #  pre-train generator
    if pre_train_gen_path:
        print("loading pretrain generator model...")
        log.write("loading pretrain generator model...")
        restore_model(G, sess, saver_gen, pre_train_gen_path)
        print("loaded")
    else:
        log.write('pre-training generator...\n')
        print('Start pre-training...')
        best = float('inf')  # lowest pre-training loss seen so far
        for epoch in range(PRE_GEN_NUM):
            s = time.time()
            # loss = pre_train_epoch(sess, G, gen_data_loader)
            loss = pre_train_epoch(sess, G, input_data_loader)
            print("Epoch ", epoch, " loss: ", loss)
            log.write("Epoch:\t" + str(epoch) + "\tloss:\t" + str(loss) + "\n")
            print("pre-train generator epoch time: ", time.time() - s, " s")
            # save a checkpoint whenever the loss improves
            if loss < best:
                saver_gen.save(sess, "./model/pre_gen/pretrain_gen_best")
                best = loss
    dev_loader = Input_Data_loader(BATCH_SIZE)
    dev_loader.create_batches(dev_x, dev_y)

    if pre_train_dis_path:
        print("loading pretrain discriminator model...")
        log.write("loading pretrain discriminator model...")
        restore_model(D, sess, saver_dis, pre_train_dis_path)
        print("loaded")
    else:
        log.write('pre-training discriminator...\n')
        print("Start pre-train the discriminator")
        s = time.time()
        for epoch in range(PRE_DIS_NUM):
            # generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file)
            generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file, input_data_loader)
            # dis_data_loader.load_train_data(positive_file, negative_file)
            dis_data_loader.load_train_data(y_file, negative_file)
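            # three passes over the real/generated data per negative set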
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        D.input_x: x_batch,
                        D.input_y: y_batch,
                        D.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, acc = sess.run([D.train_op, D.accuracy], feed)
            print("Epoch ", epoch, " Accuracy: ", acc)
            log.write("Epoch:\t" + str(epoch) + "\tAccuracy:\t" + str(acc) + "\n")
            best = 0
            # if epoch % 20  == 0 or epoch == PRE_DIS_NUM -1:
            #     print("saving at epoch: ", epoch)
            #     saver_dis.save(sess, "./model/per_dis/pretrain_dis", global_step=epoch)
            if acc > best:
                saver_dis.save(sess, "./model/pre_dis/pretrain_dis_best")
                best = acc
        print("pre-train discriminator: ", time.time() - s, " s")

    g_beta = G_beta(G, update_rate=0.8)
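    # G_beta is the SeqGAN rollout policy: a delayed copy of G whose weights
    # track G through an exponential moving average (rate 0.8 above); it
    # completes partial sequences when estimating per-token rewards below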

    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('Start adversarial training...\n')

    for total_batch in range(TOTAL_BATCH):
        s = time.time()
        for it in range(ADV_GEN_TIME):
            for i in range(input_data_loader.num_batch):
                input_x, target = input_data_loader.next_batch()
                samples = G.generate(sess, input_x)
                rewards = g_beta.get_reward(sess, samples, input_x, sample_time, D)
                avg = np.mean(np.sum(rewards, axis=1), axis=0) / SEQ_LENGTH
                print(" epoch : %d time : %di: %d avg %f" % (total_batch, it, i, avg))
                feed = {G.x: samples, G.rewards: rewards, G.inputs: input_x}
                _ = sess.run(G.g_update, feed_dict=feed)
        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            avg = np.mean(np.sum(rewards, axis=1), axis=0) / SEQ_LENGTH
            buffer = 'epoch:\t' + str(total_batch) + '\treward:\t' + str(avg) + '\n'
            print('total_batch: ', total_batch, 'average reward: ', avg)
            log.write(buffer)

            saver_seqgan.save(sess, "./model/seq_gan/seq_gan", global_step=total_batch)

        g_beta.update_params()
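        # (moves the rollout policy's weights toward the freshly updated G)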

        # train the discriminator
        for it in range(ADV_GEN_TIME // GEN_VS_DIS_TIME):
            # generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file)
            generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file, input_data_loader)
            dis_data_loader.load_train_data(y_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for batch in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        D.input_x: x_batch,
                        D.input_y: y_batch,
                        D.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(D.train_op, feed_dict=feed)
        print("Adversarial Epoch consumed: ", time.time() - s, " s")

    # final generation
    print("Finished")
    log.close()
    # save model

    print("Training Finished, starting to generating test ")
    test_loader = Input_Data_loader(batch_size=BATCH_SIZE)
    test_loader.create_batches(test_x, test_y)

    generate_samples(sess, G, BATCH_SIZE, test_num, test_file + "_final.txt", test_loader)
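

# G_beta.get_reward is not shown in this snippet. As a reference, here is a
# minimal sketch of the SeqGAN-style Monte Carlo rollout reward it presumably
# computes; `rollout_policy.rollout` and `discriminator.ypred_for_auc` are
# assumptions modeled on the public SeqGAN reference code, not confirmed here.
def mc_rollout_reward(sess, rollout_policy, discriminator, samples, input_x,
                      sample_time, seq_length):
    # rewards[b, t] = average discriminator score of sequences that keep the
    # first t+1 generated tokens and let the rollout policy finish the rest
    rewards = np.zeros((samples.shape[0], seq_length), dtype=np.float32)
    for _ in range(sample_time):
        for given_num in range(1, seq_length):
            completed = rollout_policy.rollout(sess, samples, input_x, given_num)
            scores = sess.run(discriminator.ypred_for_auc,
                              {discriminator.input_x: completed,
                               discriminator.dropout_keep_prob: 1.0})
            rewards[:, given_num - 1] += scores[:, 1]  # probability of "real"
        # the completed sequence itself needs no rollout: score it directly
        scores = sess.run(discriminator.ypred_for_auc,
                          {discriminator.input_x: samples,
                           discriminator.dropout_keep_prob: 1.0})
        rewards[:, seq_length - 1] += scores[:, 1]
    return rewards / sample_time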
Example #2
def main():
    # set random seed (may be important to the result)
    np.random.seed(SEED)
    random.seed(SEED)

    # data loader
    # gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    input_data_loader = Input_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    D = Discriminator(SEQ_LENGTH, num_class, vocab_size, dis_emb_size,
                      dis_filter_sizes, dis_num_filters, 0.2)
    G = Generator(vocab_size,
                  BATCH_SIZE,
                  EMB_DIM,
                  HIDDEN_DIM,
                  SEQ_LENGTH,
                  START_TOKEN,
                  has_input=True)

    # avoid occupying all of the GPU memory
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # sess
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # TODO: change the training data to real poems
    # gen_data_loader.create_batches(positive_file)
    input_data_loader.create_batches(x_file, y_file)
    log = open('./experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training generator...\n')
    for epoch in range(PRE_EPOCH_NUM):
        s = time.time()
        # loss = pre_train_epoch(sess, G, gen_data_loader)
        loss = pre_train_epoch_v2(sess, G, input_data_loader)
        print("Epoch ", epoch, " loss: ", loss)
        print("pre-train generator epoch time: ", time.time() - s, " s")
    dev_loader = Input_Data_loader(BATCH_SIZE)
    dev_loader.create_batches(dev_x, dev_y)
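    # score the pre-trained generator on the dev set with BLEU before
    # starting adversarial training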
    generate_samples_v2(sess, G, BATCH_SIZE, dev_num,
                        dev_file + "_no_adv" + ".txt", dev_loader)
    bleu = calc_bleu(dev_y, dev_file + "_no_adv.txt")
    print("pre-train bleu: ", bleu)
    log.write("pre-train bleu: %f " % bleu)
    print("Start pre-train the discriminator")
    s = time.time()
    for _ in range(PRE_DIS_NUM):
        # generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file)
        generate_samples_v2(sess, G, BATCH_SIZE, generated_num, negative_file,
                            input_data_loader)
        # dis_data_loader.load_train_data(positive_file, negative_file)
        dis_data_loader.load_train_data(y_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    D.input_x: x_batch,
                    D.input_y: y_batch,
                    D.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, acc = sess.run([D.train_op, D.accuracy], feed)
            print("discriminator accuracy (last batch): ", acc)
    print("pre-train discriminator: ", time.time() - s, " s")
    g_beta = G_beta(G, update_rate=0.8)
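    # rollout policy: a delayed copy of G updated by exponential moving average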

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')

    for total_batch in range(TOTAL_BATCH):
        # train generator once
        s = time.time()
        for it in range(1):
            # samples = G.generate(sess)
            # print(input_data_loader.get_all().shape)
            # input_data_loader.reset_pointer()
            # samples = []
            # for i in range(input_data_loader.num_batch):
            input_x = input_data_loader.next_batch()[0]
            samples = G.generate_v2(sess, input_x)
            # print(sample)
            # print(samples)
            rewards = g_beta.get_reward(sess, samples, sample_time, D)
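            # REINFORCE-style update: token log-likelihoods weighted by
            # their estimated rewards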
            feed = {G.x: samples, G.rewards: rewards, G.inputs: input_x}
            _ = sess.run(G.g_update, feed_dict=feed)
        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            # generate_samples(sess, G, BATCH_SIZE, generated_num, eval_file)
            # likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            avg = np.mean(np.sum(rewards, axis=1), axis=0) / SEQ_LENGTH
            buffer = 'epoch:\t' + str(total_batch) + '\treward:\t' + str(
                avg) + '\n'
            print('total_batch: ', total_batch, 'average reward: ', avg)
            log.write(buffer)
            print("generating dev sentences")

            generate_samples_v2(sess, G, BATCH_SIZE, dev_num,
                                dev_file + "_" + str(total_batch) + ".txt",
                                dev_loader)
            bleu = calc_bleu(dev_y, dev_file + "_" + str(total_batch) + ".txt")
            print("dev bleu: ", bleu)

            log.write("bleu: %.5f \n" % bleu)
        # update G_beta with weight decay
        g_beta.update_params()

        # train the discriminator
        for it in range(DIS_VS_GEN_TIME):
            # generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file)
            generate_samples_v2(sess, G, BATCH_SIZE, generated_num,
                                negative_file, input_data_loader)
            dis_data_loader.load_train_data(y_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for batch in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        D.input_x: x_batch,
                        D.input_y: y_batch,
                        D.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(D.train_op, feed_dict=feed)
        print("Adversarial Epoch consumed: ", time.time() - s, " s")
    # final generation
    print("Finished")
    log.close()
    # save model

    print("Training Finished, starting to generating test ")
    test_loader = Input_Data_loader(batch_size=BATCH_SIZE)
    test_loader.create_batches(test_x, test_y)

    generate_samples_v2(sess, G, BATCH_SIZE, test_num,
                        test_file + "_final.txt", test_loader)
Example #3 (snippet truncated: only the tail of generate_paragraph and the __main__ block are shown)
        paragraph = np.asarray(paragraph)
        paragraph = np.transpose(paragraph, (1, 0, 2))
        generated_samples.append(paragraph)

    with open(output_file, 'w') as fout:
        for batch in generated_samples:
            for lyrics in batch:
                for line in lyrics:
                    buffer = ' '.join([str(x) for x in line]) + '\n'
                    fout.write(buffer)
                fout.write('\n')
    return generated_samples


if __name__ == '__main__':
    # load rhyme table
    table = np.load("./data/_table.npy")
    G = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, table, mode='infer',
                  has_input=True)
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    restore_model(G, sess, saver, model_path=model_path)
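    # (the initialization above is immediately overwritten by the restored weights)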
    test_loader = Input_Data_loader(BATCH_SIZE)
    test_loader.create_batches(test_x, test_y)
    print("generating...")
    # generating according to input
    ret = generate_paragraph(sess, G, BATCH_SIZE, generated_num, "generated_paragraph.txt", test_loader)
    print("finished")
    # input_x = input()