Example #1
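An end-to-end training entry point for a SeqGAN image-captioning setup on MS-COCO: the generator and discriminator can each be pretrained before adversarial training begins. `mscoco`, `mscoco_negative`, `G_pretrained`, `D_pretrained`, and `SeqGAN` are project-local modules.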
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)

    dataset = mscoco(FLAGS)
    config = tf.ConfigProto()
    # Cap this process at 10% of GPU memory, but allow it to grow on demand.
    # (Written as 0.1 rather than 1 / 10, which is 0 under Python 2 division.)
    config.gpu_options.per_process_gpu_memory_fraction = 0.1
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        filter_sizes = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 24, dataset.max_words
        ]
        num_filters = [
            100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160, 160
        ]
        num_filters_total = sum(num_filters)
        info = {
            'num_classes': 3,
            'filter_sizes': filter_sizes,
            'num_filters': num_filters,
            'num_filters_total': num_filters_total,
            'l2_reg_lambda': 0.2
        }
        if FLAGS.G_is_pretrain:
            G_pretrained_model = G_pretrained(sess, dataset, conf=FLAGS)
            if FLAGS.is_train:
                G_pretrained_model.train()
            G_pretrained_model.evaluate('test', 0)
        if FLAGS.D_is_pretrain:
            negative_dataset = mscoco_negative(dataset, FLAGS)
            D_pretrained_model = D_pretrained(sess,
                                              dataset,
                                              negative_dataset,
                                              info,
                                              conf=FLAGS)
            D_pretrained_model.train()
        if FLAGS.is_train:
            model = SeqGAN(sess, dataset, info, conf=FLAGS)
            model.train()
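Note that `filter_sizes` and `num_filters` are parallel lists: each convolution width gets the corresponding filter count, so `num_filters_total` (1880 here) is presumably the width of the discriminator's concatenated feature vector.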
Example #2
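The body of a training loop for a token-level SeqGAN text model: it builds a dictionary from a text file, tokenizes it into batches, and alternates training steps with per-epoch sampling. `utils` and `SeqGAN` are project-local, and the `xrange`/`iterator.next()` idioms mark this as Python 2 code.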
    # Turns on logging.
    import logging
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    dictionary, rev_dict = utils.get_dictionary(args.text)
    num_classes = len(dictionary)

    iterator = utils.tokenize(args.text,
                              dictionary,
                              batch_size=args.batch_size,
                              seq_len=args.seq_len)

    sess = tf.Session()
    model = SeqGAN(sess,
                   num_classes,
                   logdir=args.logdir,
                   learn_phase=args.learn_phase,
                   only_cpu=args.only_cpu)
    model.build()
    model.load(ignore_missing=True)

    for epoch in xrange(1, args.num_epochs + 1):
        for step in xrange(1, args.num_steps + 1):
            logging.info('epoch %d, step %d', epoch, step)
            model.train_batch(iterator.next())

        # Generates a sample from the model.
        g = model.generate(1000)
        print(utils.detokenize(g, rev_dict))

        # Saves the model to the logdir.
        model.save()
Example #3
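A standalone sampling script: it restores a trained SeqGAN model and its saved dictionary, then generates a sample of the requested length.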
import argparse
import os

import tensorflow as tf

# Project-local modules; the import paths are assumed from the sibling examples.
import utils
from model import SeqGAN

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Sample from a trained SeqGAN model.')
    parser.add_argument('sample_len', metavar='N', type=int,
                        help='length of sample to generate')
    parser.add_argument('-t', '--dictionary', default='dictionary.pkl',
                        type=str, help='path to dictionary file')
    parser.add_argument('-d', '--logdir', default='model/', type=str,
                        help='directory of the trained model')
    parser.add_argument('-c', '--only_cpu', default=False, action='store_true',
                        help='if set, only build weights on cpu')
    args = parser.parse_args()

    if not os.path.exists(args.dictionary):
        raise ValueError('No dictionary file found: "%s". To build it, '
                         'run train.py' % args.dictionary)

    _, rev_dict = utils.get_dictionary(None, dfile=args.dictionary)
    num_classes = len(rev_dict)

    sess = tf.Session()
    model = SeqGAN(sess,
                   num_classes,
                   logdir=args.logdir,
                   only_cpu=args.only_cpu)
    model.build()
    model.load(ignore_missing=True)

    g = model.generate(args.sample_len)
    print('Generated text:', utils.detokenize(g, rev_dict))
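Assuming the script lives in a file named something like `sample.py` (a hypothetical name), it would be run as, e.g., `python sample.py 1000 -t dictionary.pkl -d model/`.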
Example #4
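A close variant of Example #2: the same dictionary-build-and-train loop, but without a log directory or a checkpoint restore before training.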
    args = parser.parse_args()

    # Turns on logging.
    import logging
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    dictionary, rev_dict = utils.get_dictionary(args.text)
    num_classes = len(dictionary)

    iterator = utils.tokenize(args.text,
                              dictionary,
                              batch_size=args.batch_size,
                              seq_len=args.seq_len)

    sess = tf.Session()
    model = SeqGAN(sess, num_classes, only_cpu=args.only_cpu)
    model.build()

    for epoch in xrange(args.num_epochs):
        for step in xrange(args.num_steps):
            logging.info('epoch %d, step %d', epoch, step)
            model.train_batch(iterator.next())

        # Generates a sample from the model.
        g = model.generate(100)
        logging.info('Epoch %d: "%s"', epoch, utils.detokenize(g, rev_dict))

        # Saves the model to the logdir.
        model.save()
Example #5
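Module-level setup from a Chainer implementation of SeqGAN: it instantiates the generator on the GPU and defines a throughput reporter used during training.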
import argparse
import chainer.serializers
import os
import sys
import pickle
import numpy as np
from model import SeqGAN
import time
import datetime
import multiprocessing as mp

pool = mp.Pool()

generator = SeqGAN(vocab_size=3000, emb_dim=128, hidden_dim=128,
                   sequence_length=40, start_token=0, lstm_layer=2
                   ).to_gpu()

batch_size = 10000


def progress_report(count, start_time, batch_size):
    # Report cumulative throughput on stderr, rewriting a single status line.
    duration = time.time() - start_time
    throughput = count * batch_size / duration
    elapsed = str(datetime.timedelta(seconds=duration)).split('.')[0]
    sys.stderr.write(
        '\rtrain {} updates ({} samples) time: {} ({:.2f} samples/sec)'
        .format(count, count * batch_size, elapsed, throughput))

negative = []
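
# A minimal usage sketch (an assumption, not part of the original snippet):
# progress_report is meant to be called once per generator update.
#
#     start = time.time()
#     for count in range(1, num_updates + 1):
#         ...  # one update on a batch of `batch_size` samples
#         progress_report(count, start, batch_size)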