示例#1
0
def train():
    """
    Train ``Model`` forever on ``ore.DATASET_KAGGLE_MNIST_TRAINING``.

    Each iteration draws a batch of 128 examples from an ``ore.RandomReader``,
    runs one ``Model.train`` step on it and prints the resulting loss. A
    checkpoint is written whenever the step counter returned by
    ``Model.train`` is a multiple of 1000. There is no exit condition; the
    loop only stops when interrupted externally.
    """
    print('training')

    batch_size = 128
    checkpoint_every = 1000

    net = Model()
    source = ore.RandomReader(ore.DATASET_KAGGLE_MNIST_TRAINING)

    while True:
        # The third value returned by next_batch is not needed here.
        batch_images, batch_labels, _ = next_batch(source, batch_size)

        current_loss, current_step = net.train(batch_images, batch_labels)
        print('loss: {}'.format(current_loss))

        if current_step % checkpoint_every == 0:
            net.save_checkpoint()
示例#2
0
def train():
    """
    Train the xWGAN graph on ``ore.DATASET_LSUN_BEDROOM_TRAINING`` until
    interrupted.

    Restores the newest checkpoint in ``FLAGS.checkpoints_dir_path`` if one
    exists (otherwise runs ``tf.global_variables_initializer()``), then loops
    forever: five discriminator updates followed by one generator update.
    Loss summaries are written with a ``tf.summary.FileWriter`` in
    ``FLAGS.logs_dir_path``. The generator loss is printed every 100 steps,
    and a checkpoint under ``<checkpoints_dir_path>/model.ckpt`` is saved
    every 500 steps.

    The loop has no exit condition. If it is left through an exception
    (e.g. KeyboardInterrupt), the summary writer is closed so buffered
    events are flushed to disk, and the exception propagates.
    """
    reader = ore.RandomReader(ore.DATASET_LSUN_BEDROOM_TRAINING)

    # tensorflow
    checkpoint_source_path = tf.train.latest_checkpoint(
        FLAGS.checkpoints_dir_path)
    checkpoint_target_path = os.path.join(FLAGS.checkpoints_dir_path,
                                          'model.ckpt')

    gan_graph = build_xwgan()
    summaries = build_summaries(gan_graph)

    # Create the Saver exactly once, after all variables exist. Constructing
    # a new tf.train.Saver() for every checkpoint adds a fresh set of
    # save/restore ops to the default graph each time, so the graph keeps
    # growing for the lifetime of the run. max_to_keep=None preserves the
    # previous on-disk behaviour (a fresh Saver per save never pruned old
    # checkpoint files).
    saver = tf.train.Saver(max_to_keep=None)

    reporter = tf.summary.FileWriter(FLAGS.logs_dir_path)

    try:
        with tf.Session() as session:
            if checkpoint_source_path is None:
                session.run(tf.global_variables_initializer())
            else:
                saver.restore(session, checkpoint_source_path)

            # Resume from the stored step. Logging SessionLog.START at this
            # step is meant to let TensorBoard give up events that an
            # earlier, interrupted run wrote past it (overlapped old data).
            global_step = session.run(gan_graph['global_step'])

            reporter.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START),
                global_step=global_step)

            while True:
                # Discriminator: five updates per generator update.
                for _ in range(5):
                    real_sources = next_real_batch(reader)
                    fake_sources = next_fake_batch()

                    fetches = [
                        gan_graph['discriminator_loss'],
                        gan_graph['discriminator_trainer'],
                        summaries['discriminator_loss_summary'],
                    ]

                    feeds = {
                        gan_graph['seed']: fake_sources,
                        gan_graph['real']: real_sources,
                    }

                    returns = session.run(fetches, feed_dict=feeds)

                    d_loss_summary = returns[2]

                # Only the summary from the last of the five updates is
                # written.
                reporter.add_summary(d_loss_summary, global_step)

                # Generator: one update on a freshly drawn seed batch.
                fake_sources = next_fake_batch()

                # NOTE(review): decided from the step *before* this update,
                # while the print/checkpoint checks below use the step
                # returned *by* it - confirm the offset is intended.
                log_fakes = (global_step % 500 == 0)

                fetches = [
                    gan_graph['global_step'],
                    gan_graph['generator_loss'],
                    gan_graph['generator_trainer'],
                    summaries['generator_loss_summary'],
                ]

                feeds = {gan_graph['seed']: fake_sources}

                if log_fakes:
                    fetches.append(summaries['generator_fake_summary'])

                returns = session.run(fetches, feed_dict=feeds)

                global_step = returns[0]
                g_loss_summary = returns[3]

                reporter.add_summary(g_loss_summary, global_step)

                if log_fakes:
                    reporter.add_summary(returns[4], global_step)

                if global_step % 100 == 0:
                    print('[{}]: {}'.format(global_step, returns[1]))

                if global_step % 500 == 0:
                    saver.save(session,
                               checkpoint_target_path,
                               global_step=gan_graph['global_step'])
    finally:
        # The loop above only ends by exception; flush pending summaries.
        reporter.close()
示例#3
0
def train():
    """
    Train the DCGAN graph until interrupted.

    Reads ``ore.DATASET_LSUN_BEDROOM_TRAINING`` when ``FLAGS.use_lsun`` is
    true, otherwise ``ore.DATASET_MNIST_TRAINING``.

    Restores the newest checkpoint in ``FLAGS.checkpoints_dir_path`` if one
    exists (otherwise runs ``tf.global_variables_initializer()``), then loops
    forever: one discriminator update followed by one generator update.
    Summaries are written with a ``tf.summary.FileWriter`` in
    ``FLAGS.logs_dir_path``. Every 500 steps ``summary_generator_plus`` is
    fetched instead of ``summary_generator``. The step is printed every 100
    steps, and a checkpoint under ``<checkpoints_dir_path>/model.ckpt`` is
    saved every 500 steps.

    The loop has no exit condition. If it is left through an exception
    (e.g. KeyboardInterrupt), the summary writer is closed so buffered
    events are flushed to disk, and the exception propagates.
    """
    if FLAGS.use_lsun:
        reader = ore.RandomReader(ore.DATASET_LSUN_BEDROOM_TRAINING)
    else:
        reader = ore.RandomReader(ore.DATASET_MNIST_TRAINING)

    # tensorflow
    checkpoint_source_path = tf.train.latest_checkpoint(
        FLAGS.checkpoints_dir_path)
    checkpoint_target_path = os.path.join(FLAGS.checkpoints_dir_path,
                                          'model.ckpt')

    gan_graph = build_dcgan()
    summaries = build_summaries(gan_graph)

    # Create the Saver exactly once, after all variables exist. Constructing
    # a new tf.train.Saver() for every checkpoint adds a fresh set of
    # save/restore ops to the default graph each time, so the graph keeps
    # growing for the lifetime of the run. max_to_keep=None preserves the
    # previous on-disk behaviour (a fresh Saver per save never pruned old
    # checkpoint files).
    saver = tf.train.Saver(max_to_keep=None)

    reporter = tf.summary.FileWriter(FLAGS.logs_dir_path)

    try:
        with tf.Session() as session:
            if checkpoint_source_path is None:
                session.run(tf.global_variables_initializer())
            else:
                saver.restore(session, checkpoint_source_path)

            # Resume from the stored step. Logging SessionLog.START at this
            # step is meant to let TensorBoard give up events that an
            # earlier, interrupted run wrote past it (overlapped old data).
            global_step = session.run(gan_graph['global_step'])

            reporter.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START),
                global_step=global_step)

            while True:
                # Discriminator update.
                real_sources = next_real_batch(reader)
                fake_sources = next_fake_batch()

                fetches = {
                    'discriminator_loss': gan_graph['discriminator_loss'],
                    'discriminator_trainer': gan_graph['discriminator_trainer'],
                    'summary': summaries['summary_discriminator'],
                }

                feeds = {
                    gan_graph['seed']: fake_sources,
                    gan_graph['real']: real_sources,
                }

                fetched = session.run(fetches, feed_dict=feeds)

                reporter.add_summary(fetched['summary'], global_step)

                # Generator update. NOTE(review): this reuses the same seed
                # batch the discriminator step just used - confirm intended.
                fetches = {
                    'global_step': gan_graph['global_step'],
                    'generator_loss': gan_graph['generator_loss'],
                    'generator_trainer': gan_graph['generator_trainer'],
                }

                feeds = {gan_graph['seed']: fake_sources}

                # NOTE(review): chosen from the step *before* this update,
                # while the print/checkpoint checks below use the step
                # returned *by* it.
                if global_step % 500 == 0:
                    fetches['summary'] = summaries['summary_generator_plus']
                else:
                    fetches['summary'] = summaries['summary_generator']

                fetched = session.run(fetches, feed_dict=feeds)

                global_step = fetched['global_step']

                reporter.add_summary(fetched['summary'], global_step)

                if global_step % 100 == 0:
                    print('[{}]'.format(global_step))

                if global_step % 500 == 0:
                    saver.save(session,
                               checkpoint_target_path,
                               global_step=gan_graph['global_step'])
    finally:
        # The loop above only ends by exception; flush pending summaries.
        reporter.close()