Example #1
def generate_gan(sess, model, negative_size=FLAGS.gan_k):
    '''Generate negative samples for the Discriminator.'''
    samples = []
    for _index, pair in enumerate(raw):
        if _index % 5000 == 0:
            print("have sampled %d pairs" % _index)
        q = pair[1]
        a = pair[2]
        distractor = pair[3]

        neg_alist_index = list(range(len(alist)))
        sampled_index = np.random.choice(neg_alist_index,
                                         size=[FLAGS.pools_size],
                                         replace=False)
        pools = np.array(alist)[sampled_index]  # true positives may end up in the pool
        # TODO: remove true positives

        # [q, a, distractor, negative sample]
        candidates = data_helpers.loadCandidateSamples(
            q, a, distractor, pools, vocab, FLAGS.max_sequence_length_q,
            FLAGS.max_sequence_length_a)
        predicteds = []
        for batch in data_helpers.batch_iter(candidates,
                                             batch_size=FLAGS.batch_size):
            feed_dict = {
                model.input_x_1: np.array(batch[:, 0].tolist()),
                model.input_x_2: np.array(batch[:, 1].tolist()),
                model.input_x_3: np.array(batch[:, 2].tolist()),
                model.input_x_4: np.array(batch[:, 3].tolist())
            }
            predicted = sess.run(model.gan_score, feed_dict)
            predicteds.extend(predicted)

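        # temperature-scaled softmax over the generator's scores: shift by the
        # max for numerical stability, add a small epsilon, and renormalize to
        # guard against NaN / all-zero rows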
        predicteds = np.array(predicteds) * FLAGS.sampled_temperature
        predicteds -= np.max(predicteds)
        exp_rating = np.exp(predicteds)
        prob = exp_rating / np.sum(exp_rating)
        prob = np.nan_to_num(prob) + 1e-7
        prob = prob / np.sum(prob)
        neg_samples = np.random.choice(pools,
                                       size=negative_size,
                                       p=prob,
                                       replace=False)
        for neg in neg_samples:
            samples.append(
                (encode_sent(vocab, q, FLAGS.max_sequence_length_q),
                 encode_sent(vocab, a, FLAGS.max_sequence_length_a),
                 encode_sent(vocab, distractor, FLAGS.max_sequence_length_a),
                 encode_sent(vocab, neg, FLAGS.max_sequence_length_a)))
    return samples
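
The sampling step above is a temperature-scaled softmax over the generator's scores, with a max-shift for numerical stability and a small epsilon guarding against NaN rows. A minimal self-contained sketch of just that step (sample_negatives, scores, pool, and k are illustrative names, not part of this repo):

import numpy as np

def sample_negatives(scores, pool, k, temperature=1.0, rng=None):
    """Draw k distinct items from pool, weighted by softmax(temperature * scores)."""
    rng = rng or np.random.default_rng()
    logits = np.asarray(scores, dtype=np.float64) * temperature
    logits -= logits.max()             # shift for numerical stability
    prob = np.exp(logits)
    prob /= prob.sum()
    prob = np.nan_to_num(prob) + 1e-7  # guard against NaN / all-zero rows
    prob /= prob.sum()                 # renormalize after the epsilon
    idx = rng.choice(len(pool), size=k, p=prob, replace=False)
    return [pool[i] for i in idx]

# e.g. sample_negatives([0.2, 1.5, -0.3], ['a1', 'a2', 'a3'], k=2)
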
Example #2
def main():
    with tf.Graph().as_default():
        with tf.device("/gpu:1"):
            # embeddings
            param = None
            if len(FLAGS.pretrained_embeddings_path) > 0:
                print('loading pretrained embeddings...')
                param = embd
            else:
                print('using randomized embeddings...')
                param = np.random.uniform(-0.05, 0.05,
                                          (len(vocab), FLAGS.embedding_dim))

            # models
            with tf.variable_scope('Dis'):
                discriminator = Discriminator.Discriminator(
                    sequence_length_q=FLAGS.max_sequence_length_q,
                    sequence_length_a=FLAGS.max_sequence_length_a,
                    batch_size=FLAGS.batch_size,
                    vocab_size=len(vocab),
                    embedding_size=FLAGS.embedding_dim,
                    hidden_size=FLAGS.hidden_size,
                    l2_reg_lambda=FLAGS.l2_reg_lambda,
                    learning_rate=FLAGS.learning_rate,
                    dropout_keep_prob=FLAGS.dropout_keep_prob,
                    padding_id=vocab[FLAGS.padding])

            with tf.variable_scope('Gen'):
                generator = Generator.Generator(
                    sequence_length_q=FLAGS.max_sequence_length_q,
                    sequence_length_a=FLAGS.max_sequence_length_a,
                    batch_size=FLAGS.batch_size,
                    vocab_size=len(vocab),
                    embedding_size=FLAGS.embedding_dim,
                    hidden_size=FLAGS.hidden_size,
                    l2_reg_lambda=FLAGS.l2_reg_lambda,
                    sampled_temperature=FLAGS.sampled_temperature,
                    learning_rate=FLAGS.learning_rate,
                    dropout_keep_prob=FLAGS.dropout_keep_prob,
                    padding_id=vocab[FLAGS.padding])

            session_conf = tf.ConfigProto(
                allow_soft_placement=FLAGS.allow_soft_placement,
                log_device_placement=FLAGS.log_device_placement)
            sess = tf.Session(config=session_conf)
            with sess.as_default(), \
                    open(log_precision, "w") as log, \
                    open(log_loss, "w") as loss_log:
                # initialize or restore
                if len(FLAGS.pretrained_model_path) == 0:
                    print('initializing model...')
                    sess.run(tf.global_variables_initializer())
                    # pretrained embeddings or randomized embeddings
                    sess.run(
                        discriminator.embedding_init,
                        feed_dict={discriminator.embedding_placeholder: param})
                    sess.run(
                        generator.embedding_init,
                        feed_dict={generator.embedding_placeholder: param})
                else:
                    print('loading pretrained model...')
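                    # restore everything except the Dis/Gen output variables,
                    # which are filtered out below (assumed absent from the
                    # pretrained checkpoint; missing vars are ignored anyway)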
                    var_list = tf.global_variables()
                    var_list = [
                        x for x in var_list
                        if not x.name.startswith('Dis/output/Variable')
                    ]
                    var_list = [
                        x for x in var_list
                        if not x.name.startswith('Gen/Variable')
                    ]
                    restore_op, feed_dict = tf.contrib.framework.assign_from_checkpoint(
                        tf.train.latest_checkpoint(FLAGS.pretrained_model_path),
                        var_list, ignore_missing_vars=True)
                    sess.run(restore_op, feed_dict)

                # initial evaluation
                saver = tf.train.Saver(max_to_keep=None)
                # evaluation(sess, discriminator, log, saver, 0, 'dev', False)
                # evaluation(sess, generator, log, saver, 0, 'dev', False)

                baseline = 0.05  # reward baseline; updated each epoch from the mean reward
                for i in range(FLAGS.num_epochs):
                    # discriminator: train on negatives sampled by the
                    # generator (skipped on the first epoch)
                    if i > 0:
                        samples = generate_gan(sess, generator, FLAGS.gan_k)
                        for _index, batch in enumerate(
                                data_helpers.batch_iter(
                                    samples,
                                    num_epochs=FLAGS.d_epochs_num,
                                    batch_size=FLAGS.batch_size,
                                    shuffle=True)):
                            feed_dict = {  # [q, a, distractor, negative sample]
                                discriminator.input_x_1:
                                np.array(batch[:, 0].tolist()),
                                discriminator.input_x_2:
                                np.array(batch[:, 1].tolist()),
                                discriminator.input_x_3:
                                np.array(batch[:, 2].tolist()),
                                discriminator.input_x_4:
                                np.array(batch[:, 3].tolist())
                            }
                            _, step, current_loss, accuracy, positive, negative = sess.run(
                                [
                                    discriminator.train_op,
                                    discriminator.global_step,
                                    discriminator.loss, discriminator.accuracy,
                                    discriminator.positive,
                                    discriminator.negative
                                ], feed_dict)

                            line = (
                                "%s: Dis step %d, loss %f with acc %f, positive %f negative %f"
                                % (datetime.datetime.now().isoformat(), step,
                                   current_loss, accuracy, positive, negative))
                            if _index % 100 == 0:
                                print(line)
                            loss_log.write(line + "\n")
                            loss_log.flush()
                        evaluation(sess, discriminator, log, saver, i, 'dev',
                                   True, False)

                    # generator: policy-gradient updates weighted by the
                    # discriminator's reward
                    baseline_avg = []
                    for g_epoch in range(FLAGS.g_epochs_num):
                        for _index, pair in enumerate(raw):
                            q = pair[1]
                            a = pair[2]
                            distractor = pair[3]

                            # the random pool may still contain true positives
                            neg_alist_index = list(range(len(alist)))
                            pos_num = min(4, len(raw_dict[q]))
                            sampled_index = np.random.choice(
                                neg_alist_index,
                                size=FLAGS.pools_size - pos_num,
                                replace=False)
                            sampled_index = list(sampled_index)
                            pools = np.array(alist)[sampled_index]
                            # add the positive index
                            positive_index = np.random.choice(
                                len(raw_dict[q]), pos_num,
                                replace=False).tolist()
                            pools = np.concatenate(
                                (pools, np.array(raw_dict[q])[positive_index]))

                            samples = data_helpers.loadCandidateSamples(
                                q, a, distractor, pools, vocab,
                                FLAGS.max_sequence_length_q,
                                FLAGS.max_sequence_length_a)
                            predicteds = []
                            for batch in data_helpers.batch_iter(
                                    samples, batch_size=FLAGS.batch_size):
                                feed_dict = {
                                    generator.input_x_1:
                                    np.array(batch[:, 0].tolist()),
                                    generator.input_x_2:
                                    np.array(batch[:, 1].tolist()),
                                    generator.input_x_3:
                                    np.array(batch[:, 2].tolist()),
                                    generator.input_x_4:
                                    np.array(batch[:, 3].tolist())
                                }
                                predicted = sess.run(generator.gan_score,
                                                     feed_dict)
                                predicteds.extend(predicted)

                            # generate FLAGS.gan_k negative samples
                            predicteds = np.array(
                                predicteds) * FLAGS.sampled_temperature
                            predicteds -= np.max(predicteds)
                            exp_rating = np.exp(predicteds)
                            prob = exp_rating / np.sum(exp_rating)
                            prob = np.nan_to_num(prob) + 1e-7
                            prob = prob / np.sum(prob)
                            neg_index = np.random.choice(np.arange(len(pools)),
                                                         size=FLAGS.gan_k,
                                                         p=prob,
                                                         replace=False)

                            subsamples = np.array(
                                data_helpers.loadCandidateSamples(
                                    q, a, distractor, pools[neg_index], vocab,
                                    FLAGS.max_sequence_length_q,
                                    FLAGS.max_sequence_length_a))
                            feed_dict = {
                                discriminator.input_x_1:
                                np.array(subsamples[:, 0].tolist()),
                                discriminator.input_x_2:
                                np.array(subsamples[:, 1].tolist()),
                                discriminator.input_x_3:
                                np.array(subsamples[:, 2].tolist()),
                                discriminator.input_x_4:
                                np.array(subsamples[:, 3].tolist())
                            }
                            reward, l2_loss_d = sess.run(
                                [discriminator.reward, discriminator.l2_loss],
                                feed_dict)
                            baseline_avg.append(np.mean(reward))
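                            # center the reward with the running baseline
                            # (variance reduction for the policy gradient)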
                            reward = reward - baseline

                            samples = np.array(samples)
                            feed_dict = {
                                generator.input_x_1:
                                np.array(samples[:, 0].tolist()),
                                generator.input_x_2:
                                np.array(samples[:, 1].tolist()),
                                generator.input_x_3:
                                np.array(samples[:, 2].tolist()),
                                generator.input_x_4:
                                np.array(samples[:, 3].tolist()),
                                generator.neg_index:
                                neg_index,
                                generator.reward:
                                reward
                            }
                            # ideally the softmax would run over all answers,
                            # but that is too expensive, so it runs over the
                            # sampled pool only
                            _, step, current_loss, positive, negative, score12, score13, l2_loss_g = sess.run(
                                [
                                    generator.gan_updates,
                                    generator.global_step, generator.gan_loss,
                                    generator.positive, generator.negative,
                                    generator.score12, generator.score13,
                                    generator.l2_loss
                                ], feed_dict)

                            line = (
                                "%s: Gen step %d, loss %f l2 %f,%f positive %f negative %f, sample prob [%f, %f], reward [%f, %f]"
                                % (datetime.datetime.now().isoformat(), step,
                                   current_loss, l2_loss_g, l2_loss_d, positive,
                                   negative, np.min(prob), np.max(prob),
                                   np.min(reward), np.max(reward)))
                            if _index % 100 == 0:
                                print(line)
                            loss_log.write(line + "\n")
                            loss_log.flush()

                        evaluation(sess, generator, log, saver,
                                   i * FLAGS.g_epochs_num + g_epoch, 'dev',
                                   True, False)
                        log.flush()
                    baseline = np.mean(baseline_avg)

                # final evaluation
                evaluation(sess, discriminator, log, saver, -1, 'test', False,
                           False, True)
                evaluation(sess, generator, log, saver, -1, 'test', False,
                           False, True)