negative_file = os.path.join(out_dir, 'generator_sample.txt')
eval_file = os.path.join(out_dir, 'eval.txt')
generated_num = 10000

gen_data_loader = Gen_Data_loader(gen_batch_size)
likelihood_data_loader = Likelihood_data_loader(gen_batch_size)
dis_data_loader = Dis_dataloader()
vocab_size = 5000
seq_length = 20
best_score = 1000
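# target_params.pkl holds the parameters of a fixed "oracle" LSTM; it acts as the
# synthetic data distribution, and generated samples are later scored against it
# (the oracle negative log-likelihood benchmark).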
target_params = pickle.load(open('save/target_params.pkl', 'rb'), encoding='bytes')
target_lstm = TARGET_LSTM(vocab_size, 64, 32, 32, 20, 0, target_params)
start_token = 0

# generator
generator = SeqGAN(seq_length, vocab_size, gen_emb_dim, gen_hidden_dim, start_token, oracle=True).to_gpu()
if args.gen:
    print(args.gen)
    serializers.load_hdf5(args.gen, generator)

# discriminator
discriminator = TextCNN(num_classes=2, vocab_size=vocab_size,
                        embedding_size=dis_embedding_dim,
                        filter_sizes=dis_filter_sizes, num_filters=dis_num_filters).to_gpu()
if args.dis:
    serializers.load_hdf5(args.dis, discriminator)

sess = tf.Session()
sess.run(tf.global_variables_initializer())  # initialize_all_variables() is deprecated
# generate the "positive" (oracle) data with the target LSTM before building batches from it
generate_samples_pos(sess, target_lstm, 64, generated_num, positive_file)
gen_data_loader.create_batches(positive_file)
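# The data loaders expect positive_file to contain one sequence per line as
# space-separated token ids (the format used by the standard SeqGAN loaders).
# A minimal sketch of what a helper like generate_samples_pos presumably does;
# the signature and the model.gen_x attribute are assumptions, not this
# project's actual implementation:
def generate_samples_sketch(sess, model, batch_size, num_samples, output_file):
    samples = []
    for _ in range(num_samples // batch_size):
        # model.gen_x is assumed to be a [batch_size, seq_length] tensor of sampled token ids
        samples.extend(sess.run(model.gen_x))
    with open(output_file, 'w') as fout:
        for seq in samples:
            fout.write(' '.join(str(int(t)) for t in seq) + '\n')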
# load the arasuji dataset
with open('dataset/arasuji.dat', 'rb') as f:
    arasuji = pickle.load(f)

train_num = len(arasuji.train_idx)
test_num = len(arasuji.test_idx)
print('train_num = {}'.format(train_num))
print('test_num = {}'.format(test_num))
vocab_size = 2000
seq_length = 40
start_token = 0

# generator
generator = SeqGAN(vocab_size=vocab_size, emb_dim=args.gen_emb_dim, hidden_dim=args.gen_hidden_dim,
                   sequence_length=seq_length, start_token=start_token, lstm_layer=args.num_lstm_layer,
                   dropout=True).to_gpu()
if args.gen:
    serializers.load_hdf5(args.gen, generator)

# discriminator
discriminator = TextCNN(num_classes=2, vocab_size=vocab_size, embedding_size=args.dis_embedding_dim,
                        filter_sizes=[int(n) for n in args.dis_filter_sizes.split(',')],
                        num_filters=[int(n) for n in args.dis_num_filters.split(',')]
                        ).to_gpu()
if args.dis:
    serializers.load_hdf5(args.dis, discriminator)

# set optimizer
gen_optimizer = optimizers.Adam(alpha=args.gen_lr)
gen_optimizer.setup(generator)
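# With the optimizer attached, a maximum-likelihood pre-training step in Chainer
# usually looks like the sketch below. generator.pretrain_step is a hypothetical
# method assumed to return the cross-entropy loss of teacher-forcing the
# generator on a batch of real sequences; only the cleargrads/backward/update
# pattern is standard Chainer.
def pretrain_update_sketch(generator, gen_optimizer, batch):
    generator.cleargrads()                 # reset accumulated gradients
    loss = generator.pretrain_step(batch)  # hypothetical: NLL of the batch
    loss.backward()                        # backprop through the LSTM
    gen_optimizer.update()                 # Adam parameter update
    return float(loss.data)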
# ===== Example #3 =====
encoder = SeqEncoder(vocab_size=vocab_size,
                     emb_dim=args.gen_emb_dim,
                     hidden_dim=args.gen_hidden_dim,
                     latent_dim=args.latent_dim,
                     sequence_length=seq_length,
                     tag_num=tag_dim).to_gpu()

if args.enc:
    serializers.load_hdf5(args.enc, encoder)

# generator
generator = SeqGAN(vocab_size=vocab_size,
                   emb_dim=args.gen_emb_dim,
                   hidden_dim=args.gen_hidden_dim,
                   sequence_length=seq_length,
                   start_token=start_token,
                   lstm_layer=args.num_lstm_layer,
                   dropout=args.dropout,
                   encoder=encoder,
                   latent_dim=args.latent_dim,
                   tag_dim=tag_dim).to_gpu()
if args.gen:
    serializers.load_hdf5(args.gen, generator)

# set optimizer
enc_optimizer = optimizers.Adam(alpha=args.gen_lr)
enc_optimizer.setup(encoder)
enc_optimizer.add_hook(chainer.optimizer.GradientClipping(args.gen_grad_clip))

gen_optimizer = optimizers.Adam(alpha=args.gen_lr)
gen_optimizer.setup(generator)
gen_optimizer.add_hook(chainer.optimizer.GradientClipping(args.gen_grad_clip))
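# When the generator is conditioned on the encoder (as above), both links are
# typically updated from a single reconstruction loss. The loss call below is a
# hypothetical placeholder; the two-optimizer update pattern itself is standard
# Chainer, and the GradientClipping hooks registered above run inside update().
def autoencoder_update_sketch(encoder, generator, enc_optimizer, gen_optimizer, batch, tags):
    encoder.cleargrads()
    generator.cleargrads()
    loss = generator.reconstruction_loss(batch, tags)  # hypothetical method
    loss.backward()
    enc_optimizer.update()
    gen_optimizer.update()
    return float(loss.data)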
# ===== Example #4 =====
start_token = 0

# encoder
encoder = SeqEncoder(vocab_size=vocab_size,
                     emb_dim=args.gen_emb_dim,
                     hidden_dim=args.gen_hidden_dim,
                     sequence_length=seq_length).to_gpu()

if args.enc:
    serializers.load_hdf5(args.enc, encoder)

# generator
generator = SeqGAN(vocab_size=vocab_size,
                   emb_dim=args.gen_emb_dim,
                   hidden_dim=args.gen_hidden_dim,
                   sequence_length=seq_length,
                   start_token=start_token,
                   lstm_layer=args.num_lstm_layer,
                   dropout=args.dropout,
                   encoder=encoder).to_gpu()
if args.gen:
    serializers.load_hdf5(args.gen, generator)

# set optimizer
enc_optimizer = optimizers.Adam(alpha=args.gen_lr)
enc_optimizer.setup(encoder)
enc_optimizer.add_hook(chainer.optimizer.GradientClipping(args.gen_grad_clip))

gen_optimizer = optimizers.Adam(alpha=args.gen_lr)
gen_optimizer.setup(generator)
gen_optimizer.add_hook(chainer.optimizer.GradientClipping(args.gen_grad_clip))
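# Adversarial fine-tuning updates the generator with a REINFORCE-style loss,
# weighting the log-likelihood of each sampled token by a reward derived from
# the discriminator (usually via Monte-Carlo rollouts). The loss method below
# is a hypothetical placeholder for whatever this project exposes.
def adversarial_update_sketch(generator, gen_optimizer, samples, rewards):
    generator.cleargrads()
    # hypothetical: -sum(reward * log p(token)) over the sampled sequences
    loss = generator.reinforcement_step(samples, rewards)
    loss.backward()
    gen_optimizer.update()
    return float(loss.data)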
train_num = len(train_comment_data)
test_num = len(test_comment_data)
vocab_size = 3000
seq_length = 30
start_token = 0

if args.ae_pretrain:
    encoder = SeqEncoder(vocab_size=vocab_size, emb_dim=args.gen_emb_dim, hidden_dim=args.gen_hidden_dim,
                         latent_dim=args.latent_dim, sequence_length=seq_length)
else:
    encoder = None

# generator
generator = SeqGAN(vocab_size=vocab_size, emb_dim=args.gen_emb_dim, hidden_dim=args.gen_hidden_dim,
                   sequence_length=seq_length, start_token=start_token, lstm_layer=args.num_lstm_layer,
                   dropout=args.dropout, free_pretrain=args.free_pretrain, encoder=encoder, tag_dim=len(tag_id)).to_gpu()
if args.gen:
    serializers.load_hdf5(args.gen, generator)

# discriminator
discriminator = TextCNN(num_classes=2, vocab_size=vocab_size, embedding_size=args.dis_embedding_dim,
                        filter_sizes=[int(n) for n in args.dis_filter_sizes.split(',')],
                        num_filters=[int(n) for n in args.dis_num_filters.split(',')]
                        ).to_gpu()
if args.dis:
    serializers.load_hdf5(args.dis, discriminator)
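# A discriminator update step usually mixes real and generated sequences and
# trains TextCNN as a binary classifier. TextCNN's call signature below is an
# assumption based on how it is constructed above; the softmax-cross-entropy
# loss and the update pattern are standard Chainer.
import numpy as np
import chainer.functions as F

def dis_update_sketch(discriminator, dis_optimizer, real_batch, fake_batch, xp=np):
    x = xp.concatenate([real_batch, fake_batch], axis=0)
    t = xp.concatenate([xp.ones(len(real_batch), dtype=xp.int32),
                        xp.zeros(len(fake_batch), dtype=xp.int32)])
    discriminator.cleargrads()
    y = discriminator(x)                   # assumed to return class logits
    loss = F.softmax_cross_entropy(y, t)
    loss.backward()
    dis_optimizer.update()
    return float(loss.data)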

# set optimizer
if args.ae_pretrain:
    enc_optimizer = optimizers.Adam(alpha=args.gen_lr)