def generate_gan(sess, model, negative_size=FLAGS.gan_k):
    '''Generate negative samples for the Discriminator.'''
    samples = []
    for _index, pair in enumerate(raw):
        if _index % 5000 == 0:
            print("have sampled %d pairs" % _index)
        q = pair[1]
        a = pair[2]
        distractor = pair[3]
        # sample a candidate pool uniformly from all answers;
        # it's possible that true positive samples are selected
        # TODO: remove true positives
        neg_alist_index = list(range(len(alist)))
        sampled_index = np.random.choice(neg_alist_index,
                                         size=[FLAGS.pools_size],
                                         replace=False)
        pools = np.array(alist)[sampled_index]
        # each candidate is a [q, a, distractor, negative sample] tuple
        candidates = data_helpers.loadCandidateSamples(
            q, a, distractor, pools, vocab,
            FLAGS.max_sequence_length_q, FLAGS.max_sequence_length_a)
        predicteds = []
        for batch in data_helpers.batch_iter(candidates,
                                             batch_size=FLAGS.batch_size):
            feed_dict = {
                model.input_x_1: np.array(batch[:, 0].tolist()),
                model.input_x_2: np.array(batch[:, 1].tolist()),
                model.input_x_3: np.array(batch[:, 2].tolist()),
                model.input_x_4: np.array(batch[:, 3].tolist())
            }
            predicted = sess.run(model.gan_score, feed_dict)
            predicteds.extend(predicted)
        # temperature-scaled softmax over the generator scores
        predicteds = np.array(predicteds) * FLAGS.sampled_temperature
        predicteds -= np.max(predicteds)  # stabilize exp() against overflow
        exp_rating = np.exp(predicteds)
        prob = exp_rating / np.sum(exp_rating)
        prob = np.nan_to_num(prob) + 1e-7  # smooth away zeros/NaNs
        prob = prob / np.sum(prob)
        neg_samples = np.random.choice(pools, size=negative_size, p=prob,
                                       replace=False)
        for neg in neg_samples:
            samples.append(
                (encode_sent(vocab, q, FLAGS.max_sequence_length_q),
                 encode_sent(vocab, a, FLAGS.max_sequence_length_a),
                 encode_sent(vocab, distractor, FLAGS.max_sequence_length_a),
                 encode_sent(vocab, neg, FLAGS.max_sequence_length_a)))
    return samples
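

# A minimal standalone sketch of the sampling step used in generate_gan()
# above (and again inline in main() below): raw scores become a Boltzmann
# distribution via a temperature-scaled, numerically stabilized softmax.
# The helper name and signature are illustrative, not part of the original
# code; the training loops keep this logic inline.
def softmax_sample_indices(scores, temperature, size):
    '''Sample `size` distinct indices with probability proportional to
    exp(score * temperature).'''
    logits = np.array(scores) * temperature
    logits -= np.max(logits)           # stabilize exp() against overflow
    prob = np.exp(logits)
    prob = prob / np.sum(prob)
    prob = np.nan_to_num(prob) + 1e-7  # smooth away zeros/NaNs
    prob = prob / np.sum(prob)         # renormalize after smoothing
    return np.random.choice(np.arange(len(prob)), size=size, p=prob,
                            replace=False)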
def main():
    with tf.Graph().as_default():
        with tf.device("/gpu:1"):
            # embeddings: pretrained if a path is given, otherwise randomized
            param = None
            if len(FLAGS.pretrained_embeddings_path) > 0:
                print('loading pretrained embeddings...')
                param = embd
            else:
                print('using randomized embeddings...')
                param = np.random.uniform(-0.05, 0.05,
                                          (len(vocab), FLAGS.embedding_dim))

            # models
            with tf.variable_scope('Dis'):
                discriminator = Discriminator.Discriminator(
                    sequence_length_q=FLAGS.max_sequence_length_q,
                    sequence_length_a=FLAGS.max_sequence_length_a,
                    batch_size=FLAGS.batch_size,
                    vocab_size=len(vocab),
                    embedding_size=FLAGS.embedding_dim,
                    hidden_size=FLAGS.hidden_size,
                    l2_reg_lambda=FLAGS.l2_reg_lambda,
                    learning_rate=FLAGS.learning_rate,
                    dropout_keep_prob=FLAGS.dropout_keep_prob,
                    padding_id=vocab[FLAGS.padding])
            with tf.variable_scope('Gen'):
                generator = Generator.Generator(
                    sequence_length_q=FLAGS.max_sequence_length_q,
                    sequence_length_a=FLAGS.max_sequence_length_a,
                    batch_size=FLAGS.batch_size,
                    vocab_size=len(vocab),
                    embedding_size=FLAGS.embedding_dim,
                    hidden_size=FLAGS.hidden_size,
                    l2_reg_lambda=FLAGS.l2_reg_lambda,
                    sampled_temperature=FLAGS.sampled_temperature,
                    learning_rate=FLAGS.learning_rate,
                    dropout_keep_prob=FLAGS.dropout_keep_prob,
                    padding_id=vocab[FLAGS.padding])

            session_conf = tf.ConfigProto(
                allow_soft_placement=FLAGS.allow_soft_placement,
                log_device_placement=FLAGS.log_device_placement)
            sess = tf.Session(config=session_conf)
            with sess.as_default(), \
                    open(log_precision, "w") as log, \
                    open(log_loss, "w") as loss_log:
                # initialize or restore
                if len(FLAGS.pretrained_model_path) == 0:
                    print('initializing model...')
                    sess.run(tf.global_variables_initializer())
                    # pretrained embeddings or randomized embeddings
                    sess.run(discriminator.embedding_init,
                             feed_dict={
                                 discriminator.embedding_placeholder: param
                             })
                    sess.run(generator.embedding_init,
                             feed_dict={
                                 generator.embedding_placeholder: param
                             })
                else:
                    print('loading pretrained model...')
                    var_list = tf.global_variables()
                    var_list = [
                        x for x in var_list
                        if not x.name.startswith('Dis/output/Variable')
                    ]
                    var_list = [
                        x for x in var_list
                        if not x.name.startswith('Gen/Variable')
                    ]
                    restore_op, feed_dict = tf.contrib.framework.assign_from_checkpoint(
                        tf.train.latest_checkpoint(FLAGS.pretrained_model_path),
                        var_list, True)
                    sess.run(restore_op, feed_dict)

                # initial evaluation
                saver = tf.train.Saver(max_to_keep=None)
                # evaluation(sess, discriminator, log, saver, 0, 'dev', False)
                # evaluation(sess, generator, log, saver, 0, 'dev', False)

                baseline = 0.05
                for i in range(FLAGS.num_epochs):
                    # discriminator: train on negatives sampled by the generator
                    if i > 0:
                        samples = generate_gan(sess, generator, FLAGS.gan_k)
                        for _index, batch in enumerate(
                                data_helpers.batch_iter(
                                    samples,
                                    num_epochs=FLAGS.d_epochs_num,
                                    batch_size=FLAGS.batch_size,
                                    shuffle=True)):
                            # [q, a, distractor, negative sample]
                            feed_dict = {
                                discriminator.input_x_1: np.array(batch[:, 0].tolist()),
                                discriminator.input_x_2: np.array(batch[:, 1].tolist()),
                                discriminator.input_x_3: np.array(batch[:, 2].tolist()),
                                discriminator.input_x_4: np.array(batch[:, 3].tolist())
                            }
                            _, step, current_loss, accuracy, positive, negative = sess.run(
                                [
                                    discriminator.train_op,
                                    discriminator.global_step,
                                    discriminator.loss,
                                    discriminator.accuracy,
                                    discriminator.positive,
                                    discriminator.negative
                                ], feed_dict)
                            line = ("%s: Dis step %d, loss %f with acc %f, positive %f negative %f"
                                    % (datetime.datetime.now().isoformat(), step,
                                       current_loss, accuracy, positive, negative))
                            if _index % 100 == 0:
                                print(line)
                            loss_log.write(line + "\n")
                            loss_log.flush()
                        evaluation(sess, discriminator, log, saver, i, 'dev',
                                   True, False)

                    # generator
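                    # The generator update below is REINFORCE-style: negatives
                    # are sampled from the generator's softmax, scored by the
                    # discriminator, and (reward - baseline) weights the update,
                    # with the baseline set to the previous epoch's mean reward
                    # to reduce gradient variance. (How neg_index and reward are
                    # consumed is defined inside Generator, not shown here.)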
                    baseline_avg = []
                    for g_epoch in range(FLAGS.g_epochs_num):
                        for _index, pair in enumerate(raw):
                            q = pair[1]
                            a = pair[2]
                            distractor = pair[3]
                            # build a candidate pool; it's possible that true
                            # positive samples are selected
                            neg_alist_index = [j for j in range(len(alist))]
                            pos_num = min(4, len(raw_dict[q]))
                            sampled_index = np.random.choice(
                                neg_alist_index,
                                size=FLAGS.pools_size - pos_num,
                                replace=False)
                            sampled_index = list(sampled_index)
                            pools = np.array(alist)[sampled_index]
                            # add sampled positive answers for this question
                            positive_index = [
                                j for j in range(len(raw_dict[q]))
                            ]
                            positive_index = np.random.choice(
                                positive_index, pos_num, replace=False).tolist()
                            pools = np.concatenate(
                                (pools, np.array(raw_dict[q])[positive_index]))
                            samples = data_helpers.loadCandidateSamples(
                                q, a, distractor, pools, vocab,
                                FLAGS.max_sequence_length_q,
                                FLAGS.max_sequence_length_a)
                            predicteds = []
                            for batch in data_helpers.batch_iter(
                                    samples, batch_size=FLAGS.batch_size):
                                feed_dict = {
                                    generator.input_x_1: np.array(batch[:, 0].tolist()),
                                    generator.input_x_2: np.array(batch[:, 1].tolist()),
                                    generator.input_x_3: np.array(batch[:, 2].tolist()),
                                    generator.input_x_4: np.array(batch[:, 3].tolist())
                                }
                                predicted = sess.run(generator.gan_score, feed_dict)
                                predicteds.extend(predicted)

                            # sample FLAGS.gan_k negatives from a
                            # temperature-scaled softmax over the scores
                            predicteds = np.array(predicteds) * FLAGS.sampled_temperature
                            predicteds -= np.max(predicteds)  # stabilize exp()
                            exp_rating = np.exp(predicteds)
                            prob = exp_rating / np.sum(exp_rating)
                            prob = np.nan_to_num(prob) + 1e-7
                            prob = prob / np.sum(prob)
                            neg_index = np.random.choice(np.arange(len(pools)),
                                                         size=FLAGS.gan_k,
                                                         p=prob,
                                                         replace=False)

                            # score the sampled negatives with the discriminator
                            subsamples = np.array(
                                data_helpers.loadCandidateSamples(
                                    q, a, distractor, pools[neg_index], vocab,
                                    FLAGS.max_sequence_length_q,
                                    FLAGS.max_sequence_length_a))
                            feed_dict = {
                                discriminator.input_x_1: np.array(subsamples[:, 0].tolist()),
                                discriminator.input_x_2: np.array(subsamples[:, 1].tolist()),
                                discriminator.input_x_3: np.array(subsamples[:, 2].tolist()),
                                discriminator.input_x_4: np.array(subsamples[:, 3].tolist())
                            }
                            reward, l2_loss_d = sess.run(
                                [discriminator.reward, discriminator.l2_loss],
                                feed_dict)
                            baseline_avg.append(np.mean(reward))
                            reward = reward - baseline

                            samples = np.array(samples)
                            feed_dict = {
                                generator.input_x_1: np.array(samples[:, 0].tolist()),
                                generator.input_x_2: np.array(samples[:, 1].tolist()),
                                generator.input_x_3: np.array(samples[:, 2].tolist()),
                                generator.input_x_4: np.array(samples[:, 3].tolist()),
                                generator.neg_index: neg_index,
                                generator.reward: reward
                            }
                            # should be a softmax over all answers, but that is
                            # too computationally expensive
                            _, step, current_loss, positive, negative, score12, score13, l2_loss_g = sess.run(
                                [
                                    generator.gan_updates,
                                    generator.global_step,
                                    generator.gan_loss,
                                    generator.positive,
                                    generator.negative,
                                    generator.score12,
                                    generator.score13,
                                    generator.l2_loss
                                ], feed_dict)
                            line = ("%s: Gen step %d, loss %f l2 %f,%f positive %f negative %f, sample prob [%f, %f], reward [%f, %f]"
                                    % (datetime.datetime.now().isoformat(), step,
                                       current_loss, l2_loss_g, l2_loss_d,
                                       positive, negative,
                                       np.min(prob), np.max(prob),
                                       np.min(reward), np.max(reward)))
                            if _index % 100 == 0:
                                print(line)
                            loss_log.write(line + "\n")
                            loss_log.flush()
                        evaluation(sess, generator, log, saver,
                                   i * FLAGS.g_epochs_num + g_epoch, 'dev',
                                   True, False)
                        log.flush()
                    # refresh the reward baseline with this epoch's mean reward
                    baseline = np.mean(baseline_avg)

                # final evaluation
                evaluation(sess, discriminator, log, saver, -1, 'test',
                           False, False, True)
                evaluation(sess, generator, log, saver, -1, 'test',
                           False, False, True)
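

# Standard entry-point guard (an addition, assuming main() is not invoked
# elsewhere in this file).
if __name__ == '__main__':
    main()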