Code Example #1
def main(_):
    if not FLAGS.checkpoint_dir:
        raise ValueError(
            'You must supply the checkpoint_dir with --checkpoint_dir')

    # checkpoint_dir for each combination of hyper-parameters
    checkpoint_dir = configuration.hyperparameters_dir(FLAGS.checkpoint_dir)

    if not tf.gfile.IsDirectory(checkpoint_dir):
        raise ValueError('checkpoint_dir must be a folder path')

    with tf.Graph().as_default():

        # Build the DiscoGAN model.
        model = disco.DiscoGAN(mode="translate")
        model.build()

        # Restore the moving-average versions of the learned variables for image translation.
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()

        # Set up the Saver for saving and restoring model checkpoints.
        saver = tf.train.Saver(variables_to_restore)

        # Read dataset
        data_A, data_B = data.get_data()
        data_size = min(len(data_A), len(data_B))
        A_path, B_path = data.get_batch(FLAGS.batch_size, data_A, data_B, 0, 0,
                                        0, data_size)
        images_A = data.read_images(A_path, None, FLAGS.image_size)
        images_B = data.read_images(B_path, None, FLAGS.image_size)

        # Translate images using every saved checkpoint
        if FLAGS.is_all_checkpoints:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            for checkpoint_path in reversed(ckpt.all_model_checkpoint_paths):
                if not os.path.exists(
                        checkpoint_path + '.data-00000-of-00001'):
                    raise ValueError("No checkpoint files found in: %s" %
                                     checkpoint_path)
                print(checkpoint_path)

                A2B, B2A, A2B2A, B2A2B = run_generator_once(
                    saver, checkpoint_path, model, images_A, images_B)

                squared_A = make_squared_image(np.copy(images_A))
                squared_B = make_squared_image(np.copy(images_B))
                squared_A2B = make_squared_image(A2B)
                squared_B2A = make_squared_image(B2A)
                squared_A2B2A = make_squared_image(A2B2A)
                squared_B2A2B = make_squared_image(B2A2B)

                domain_A_images = merge_images(squared_A, squared_A2B,
                                               squared_A2B2A)
                domain_B_images = merge_images(squared_B, squared_B2A,
                                               squared_B2A2B)

                checkpoint_step = int(
                    os.path.basename(checkpoint_path).split('-')[1])
                ImageWrite(domain_A_images, 'domain_A2B', checkpoint_step)
                ImageWrite(domain_B_images, 'domain_B2A', checkpoint_step)

        # Translate images using the latest or a specific checkpoint
        else:
            if FLAGS.checkpoint_step == -1:
                checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            else:
                checkpoint_path = os.path.join(
                    checkpoint_dir, 'model.ckpt-%d' % FLAGS.checkpoint_step)

            if not os.path.exists(
                    checkpoint_path + '.data-00000-of-00001'):
                raise ValueError("No checkpoint files found in: %s" %
                                 checkpoint_path)

            A2B, B2A, A2B2A, B2A2B = run_generator_once(
                saver, checkpoint_path, model, images_A, images_B)

            squared_A = make_squared_image(images_A)
            squared_B = make_squared_image(images_B)
            squared_A2B = make_squared_image(A2B)
            squared_B2A = make_squared_image(B2A)
            squared_A2B2A = make_squared_image(A2B2A)
            squared_B2A2B = make_squared_image(B2A2B)

            domain_A_images = merge_images(squared_A, squared_A2B,
                                           squared_A2B2A)
            domain_B_images = merge_images(squared_B, squared_B2A,
                                           squared_B2A2B)

            checkpoint_step = int(
                os.path.basename(checkpoint_path).split('-')[1])
            ImageWrite(domain_A_images, 'domain_A2B', checkpoint_step)
            ImageWrite(domain_B_images, 'domain_B2A', checkpoint_step)

        print('Image translation complete.')
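
The helpers run_generator_once, make_squared_image, and merge_images are referenced above but not defined in this listing. A minimal sketch of what they might look like, assuming images are float numpy arrays of shape [batch, height, width, channels] and that the model exposes its four translation outputs as attributes (the exact attribute names are assumptions):

import numpy as np
import tensorflow as tf

def run_generator_once(saver, checkpoint_path, model, images_A, images_B):
    # Hypothetical sketch: restore one checkpoint and run all four
    # translation outputs in a single session pass.
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        return sess.run([model.A2B, model.B2A, model.A2B2A, model.B2A2B],
                        feed_dict={model.images_A: images_A,
                                   model.images_B: images_B})

def make_squared_image(images):
    # Tile a batch [N, H, W, C] into a single, roughly square grid image.
    n, h, w, c = images.shape
    cols = int(np.ceil(np.sqrt(n)))
    rows = int(np.ceil(n / float(cols)))
    grid = np.zeros((rows * h, cols * w, c), dtype=images.dtype)
    for idx in range(n):
        r, col = divmod(idx, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w, :] = images[idx]
    return grid

def merge_images(*grids):
    # Lay the grids out side by side, e.g. original | A2B | A2B2A.
    return np.concatenate(grids, axis=1)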
Code Example #2
def main(_):

    # train_dir path for each combination of hyper-parameters
    train_dir = configuration.hyperparameters_dir(FLAGS.train_dir)

    if tf.gfile.Exists(train_dir):
        raise ValueError('train_dir already exists: %s' % train_dir)
    tf.gfile.MakeDirs(train_dir)

    with tf.Graph().as_default():

        # Build the model.
        model = disco.DiscoGAN(mode="train")
        model.build()

        # Create global step
        global_step = slim.create_global_step()

        # Constant learning rate (no decay)
        learning_rate = tf.constant(FLAGS.initial_learning_rate)
        tf.summary.scalar('learning_rate', learning_rate)

        # Create an optimizer that performs gradient descent for Discriminator.
        opt_D = tf.train.AdamOptimizer(learning_rate,
                                       beta1=FLAGS.adam_beta1,
                                       beta2=FLAGS.adam_beta2,
                                       epsilon=FLAGS.adam_epsilon)

        # Create an optimizer that performs gradient descent for Generator.
        opt_G = tf.train.AdamOptimizer(learning_rate,
                                       beta1=FLAGS.adam_beta1,
                                       beta2=FLAGS.adam_beta2,
                                       epsilon=FLAGS.adam_epsilon)

        # Minimize the losses. Only opt_op_G receives global_step, so the
        # step counter advances once per training iteration, not twice.
        opt_op_D = opt_D.minimize(model.loss_Discriminator,
                                  var_list=model.D_vars)
        opt_op_G = opt_G.minimize(model.loss_Generator,
                                  global_step=global_step,
                                  var_list=model.G_vars)

        # Track the moving averages of all trainable variables.
        # Note that we maintain a "double-average" of the BatchNormalization
        # global statistics. This is more complicated than it needs to be,
        # but we keep it for backward compatibility with our previous models.
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.MOVING_AVERAGE_DECAY, global_step)

        # Another possibility is to use tf.slim.get_variables().
        variables_to_average = (tf.trainable_variables() +
                                tf.moving_average_variables())
        variables_averages_op = variable_averages.apply(variables_to_average)

        # Batch normalization update
        batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        batchnorm_updates_op = tf.group(*batchnorm_updates)

        train_op = tf.group(opt_op_D, opt_op_G, variables_averages_op,
                            batchnorm_updates_op)

        # Set up the Saver for saving and restoring model checkpoints.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

        # Start running operations on the Graph.
        with tf.Session() as sess:
            # Build an initialization operation to run below.
            init = tf.global_variables_initializer()
            sess.run(init)

            # Start the queue runners.
            tf.train.start_queue_runners(sess=sess)

            # Create a summary writer, add the 'graph' to the event file.
            summary_writer = tf.summary.FileWriter(train_dir, sess.graph)

            # Retain the summaries and build the summary operation
            summary_op = tf.summary.merge_all()

            # Read dataset
            data_A, data_B = data.get_data()
            data_size = min(len(data_A), len(data_B))

            pre_epochs = 0.0

            for step in range(FLAGS.max_steps + 1):
                start_time = time.time()

                epochs = step * FLAGS.batch_size / data_size
                A_path, B_path = data.get_batch(FLAGS.batch_size, data_A,
                                                data_B, pre_epochs, epochs,
                                                step, data_size)

                images_A = data.read_images(A_path, None, FLAGS.image_size)
                images_B = data.read_images(B_path, None, FLAGS.image_size)

                feed_dict = {
                    model.images_A: images_A,
                    model.images_B: images_B
                }
                _, loss_D, loss_G = sess.run(
                    [train_op, model.loss_Discriminator, model.loss_Generator],
                    feed_dict=feed_dict)

                pre_epochs = epochs
                duration = time.time() - start_time

                # Print training status to the console.
                examples_per_sec = FLAGS.batch_size / float(duration)
                print(
                    "Epochs: %.2f step: %d  loss_D: %f loss_G: %f (%.1f examples/sec; %.3f sec/batch)"
                    %
                    (epochs, step, loss_D, loss_G, examples_per_sec, duration))

                if step % 200 == 0:
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)

                # Save the model checkpoint periodically.
                if step % FLAGS.save_steps == 0:
                    checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

        print('Training complete.')
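
Example #2 folds batchnorm_updates_op into train_op with tf.group; the more common TF1 idiom, which Example #5 below also uses, is to make the minimize step itself depend on the UPDATE_OPS collection. A minimal self-contained sketch of that pattern:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
is_training = tf.placeholder(tf.bool, [])
h = tf.layers.batch_normalization(x, training=is_training)
loss = tf.reduce_mean(tf.square(h))

# batch_normalization registers its moving-mean/variance updates in
# UPDATE_OPS; putting minimize() under control_dependencies guarantees
# they run on every training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)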
Code Example #3
def main(_):

  with tf.Graph().as_default():

    # Build the model.
    model = disco.DiscoGAN(mode="train")
    model.build()

    # Constant learning rate (no decay)
    learning_rate = tf.constant(FLAGS.initial_learning_rate)
    tf.summary.scalar('learning_rate', learning_rate)

    # Create an optimizer that performs gradient descent for Discriminator.
    opt_D = tf.train.AdamOptimizer(
                learning_rate,
                beta1=FLAGS.adam_beta1,
                beta2=FLAGS.adam_beta2,
                epsilon=FLAGS.adam_epsilon)

    # Create an optimizer that performs gradient descent for Generator.
    opt_G = tf.train.AdamOptimizer(
                learning_rate,
                beta1=FLAGS.adam_beta1,
                beta2=FLAGS.adam_beta2,
                epsilon=FLAGS.adam_epsilon)

    # Minimize the losses; only opt_D_op advances the global step.
    opt_D_op = opt_D.minimize(model.loss_Discriminator,
                              global_step=model.global_step,
                              var_list=model.D_vars)
    opt_G_op = opt_G.minimize(model.loss_Generator,
                              var_list=model.G_vars)

    # Track the moving averages of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.MOVING_AVERAGE_DECAY, model.global_step)
    variables_to_average = tf.trainable_variables()
    variables_averages_op = variable_averages.apply(variables_to_average)

    # Batch normalization update
    batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    batchnorm_updates_op = tf.group(*batchnorm_updates)

    train_op = tf.group(opt_D_op, opt_G_op, variables_averages_op,
                        batchnorm_updates_op)


    # Set up the Saver for saving and restoring model checkpoints.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

    # Build the summary operation
    summary_op = tf.summary.merge_all()

    # train_dir path for each combination of hyper-parameters
    train_dir = configuration.hyperparameters_dir(FLAGS.train_dir)

    # Training with tf.train.Supervisor.
    sv = tf.train.Supervisor(logdir=train_dir,
                             summary_op=None,     # Do not run the summary services
                             saver=saver,
                             save_model_secs=0,   # Do not run the save_model services
                             init_fn=None)        # Not use pre-trained model
    # Start running operations on the Graph.
    with sv.managed_session() as sess:
      tf.logging.info('Start Session')

      # Start the queue runners.
      sv.start_queue_runners(sess=sess)
      tf.logging.info('Starting Queues.')

      # Read dataset
      data_A, data_B = data.get_data()
      data_size = min(len(data_A), len(data_B))

      pre_epochs = 0.0

      for step in range(FLAGS.max_steps):
        start_time = time.time()
        if sv.should_stop():
          break

        epochs = step * FLAGS.batch_size / data_size
        A_path, B_path = data.get_batch(FLAGS.batch_size, data_A, data_B, pre_epochs, epochs, step, data_size)

        images_A = data.read_images(A_path, None, FLAGS.image_size)
        images_B = data.read_images(B_path, None, FLAGS.image_size)

        feed_dict = {model.images_A: images_A,
                     model.images_B: images_B}
        _, _global_step, loss_D, loss_G = sess.run([train_op,
                                                    sv.global_step,
                                                    model.loss_Discriminator,
                                                    model.loss_Generator],
                                                    feed_dict=feed_dict)

        pre_epochs = epochs
        duration = time.time() - start_time

        # Monitor training progress in the console.
        if _global_step % 10 == 0:
          examples_per_sec = FLAGS.batch_size / float(duration)
          print("Epochs: %.2f global_step: %d  loss_D: %f loss_G: %f (%.1f examples/sec; %.3f sec/batch)"
                    % (epochs, _global_step, loss_D, loss_G, examples_per_sec, duration))

        # Save the model summaries periodically.
        if _global_step % 200 == 0:
          summary_str = sess.run(summary_op, feed_dict=feed_dict)
          sv.summary_computed(sess, summary_str)

        # Save the model checkpoint periodically.
        if _global_step % FLAGS.save_steps == 0:
          tf.logging.info('Saving model with global step %d to disk.' % _global_step)
          sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

    tf.logging.info('Training complete.')
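
configuration.hyperparameters_dir, called in every example here, is not shown. It presumably derives a per-run subdirectory from the flag values so that each hyper-parameter combination gets its own train/checkpoint folder. A hypothetical sketch; which flags are encoded in the name is a guess:

import os

def hyperparameters_dir(base_dir):
    # Hypothetical: name the subfolder after a few hyper-parameters so
    # different runs never collide (FLAGS is module-global, as above).
    name = 'lr_%g_batch_%d' % (FLAGS.initial_learning_rate,
                               FLAGS.batch_size)
    return os.path.join(base_dir, name)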
Code Example #4
File: generate.py  Project: ilguyi/dcgan.tensorflow
def main(_):
    if not FLAGS.checkpoint_dir:
        raise ValueError(
            'You must supply the checkpoint_dir with --checkpoint_dir')

    # checkpoint_dir for each combination of hyper-parameters
    checkpoint_dir = configuration.hyperparameters_dir(FLAGS.checkpoint_dir)

    if not tf.gfile.IsDirectory(checkpoint_dir):
        raise ValueError('checkpoint_dir must be a folder path')

    with tf.Graph().as_default():
        # Build the generative model.
        model = dcgan.DeepConvGANModel(mode="generate")
        model.build()

        # Restore the moving-average versions of the learned variables for image generation.
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()

        # Set up the Saver for saving and restoring model checkpoints.
        saver = tf.train.Saver(variables_to_restore)

        # Generate images for all checkpoints
        if FLAGS.make_gif:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

            # Set fixed random vectors
            np.random.seed(FLAGS.seed)
            random_z = np.random.uniform(-1, 1, [FLAGS.batch_size, 1, 1, 100])

            generated_gifs = []
            for checkpoint_path in ckpt.all_model_checkpoint_paths:
                if not os.path.exists(
                        checkpoint_path + '.data-00000-of-00001'):
                    raise ValueError("No checkpoint files found in: %s" %
                                     checkpoint_path)
                print(checkpoint_path)

                generated_images = run_generator_once(saver, checkpoint_path,
                                                      model, random_z)
                squared_images = make_squared_image(generated_images)
                checkpoint_step = int(
                    checkpoint_path.split('/')[-1].split('-')[-1])
                generated_gifs.append((squared_images, checkpoint_step))

            GIFWrite(generated_gifs)

        else:
            if FLAGS.checkpoint_step == -1:
                checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            else:
                checkpoint_path = os.path.join(
                    checkpoint_dir, 'model.ckpt-%d' % FLAGS.checkpoint_step)

            if not os.path.exists(
                    checkpoint_path + '.data-00000-of-00001'):
                raise ValueError("No checkpoint files found in: %s" %
                                 checkpoint_path)

            # Set fixed random vectors
            np.random.seed(FLAGS.seed)
            random_z = np.random.uniform(-1, 1, [FLAGS.batch_size, 1, 1, 100])

            generated_images = run_generator_once(saver, checkpoint_path,
                                                  model, random_z)
            squared_images = make_squared_image(generated_images)

            checkpoint_step = int(
                checkpoint_path.split('/')[-1].split('-')[-1])
            ImageWrite(squared_images, checkpoint_step)

        print('Image generation complete.')
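
GIFWrite, which receives the (grid image, checkpoint step) pairs collected above, is also not shown. A hypothetical sketch using imageio, assuming the generator emits tanh-range values in [-1, 1]:

import imageio
import numpy as np

def GIFWrite(generated_gifs, path='generated.gif', duration=0.5):
    # Hypothetical sketch: one GIF frame per checkpoint, ordered by step.
    frames = []
    for image, step in sorted(generated_gifs, key=lambda pair: pair[1]):
        # Map tanh-range floats [-1, 1] to uint8 pixels [0, 255].
        frames.append(((image + 1.0) * 127.5).astype(np.uint8))
    imageio.mimsave(path, frames, duration=duration)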
Code Example #5
def main(_):

    with tf.Graph().as_default():

        # Build the model.
        model = dcgan.DeepConvGANModel(mode="train")
        model.build()

        # The global step is created inside the model (model.global_step).

        # Calculate the learning rate schedule.
        num_batches_per_epoch = (FLAGS.num_examples / FLAGS.batch_size)
        print("num_batches_per_epoch: %f" % num_batches_per_epoch)
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)

        # Keep the learning rate constant; the exponential-decay schedule
        # below is defined but not currently applied.
        learning_rate = tf.constant(FLAGS.initial_learning_rate)

        def _learning_rate_decay_fn(learning_rate, global_step):
            return tf.train.exponential_decay(
                learning_rate,
                global_step,
                decay_steps=decay_steps,
                decay_rate=FLAGS.learning_rate_decay_factor,
                staircase=True)

        tf.summary.scalar('learning_rate', learning_rate)

        # Create an optimizer that performs gradient descent for Discriminator.
        opt_D = SetOptimizer(learning_rate, FLAGS.optimizer)
        # Create an optimizer that performs gradient descent for Generator.
        opt_G = SetOptimizer(learning_rate, FLAGS.optimizer)

        # Track the moving averages of all trainable variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.MOVING_AVERAGE_DECAY, model.global_step)
        variables_to_average = tf.trainable_variables()
        variables_averages_op = variable_averages.apply(variables_to_average)

        # Batch normalization update
        batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        batchnorm_updates_op = tf.group(*batchnorm_updates)

        # Make each training step also run the moving-average and batch-norm updates.
        with tf.control_dependencies(
            [variables_averages_op, batchnorm_updates_op]):
            # Minimize the losses. One training step runs both optimizers
            # once; only opt_G_op advances the global step.
            opt_D_op = opt_D.minimize(model.loss_Discriminator,
                                      var_list=model.D_vars)
            opt_G_op = opt_G.minimize(model.loss_Generator,
                                      global_step=model.global_step,
                                      var_list=model.G_vars)

        # Set up the Saver for saving and restoring model checkpoints.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

        # Build the summary operation
        summary_op = tf.summary.merge_all()

        # train_dir path in each the combination of hyper-parameters
        train_dir = configuration.hyperparameters_dir(FLAGS.train_dir)

        # Training with tf.train.Supervisor.
        sv = tf.train.Supervisor(
            logdir=train_dir,
            summary_op=None,  # Do not run the summary services
            saver=saver,
            save_model_secs=0,  # Do not run the save_model services
            init_fn=None)  # Not use pre-trained model
        # Start running operations on the Graph.
        with sv.managed_session() as sess:
            tf.logging.info('Start Session.')

            # Start the queue runners.
            sv.start_queue_runners(sess=sess)
            tf.logging.info('Starting Queues.')

            # Run the training loop.
            for epoch in range(FLAGS.max_epochs):
                for j in range(int(num_batches_per_epoch)):
                    start_time = time.time()
                    if sv.should_stop():
                        break

                    for _ in range(FLAGS.k):
                        _, loss_D = sess.run(
                            [opt_D_op, model.loss_Discriminator])
                    _, _global_step, loss_G = sess.run(
                        [opt_G_op, sv.global_step, model.loss_Generator])

                    epochs = epoch + j / num_batches_per_epoch
                    duration = time.time() - start_time

                    # Monitor training progress in the console.
                    if _global_step % 10 == 0:
                        examples_per_sec = FLAGS.batch_size / float(duration)
                        print(
                            "Epochs: %.3f global step: %d  loss_D: %f loss_G: %f (%.1f examples/sec; %.3f sec/batch)"
                            % (epochs, _global_step, loss_D, loss_G,
                               examples_per_sec, duration))

                    # Save the model summaries periodically.
                    if _global_step % 200 == 0:
                        summary_str = sess.run(summary_op)
                        sv.summary_computed(sess, summary_str)

                    # Save the model checkpoint periodically.
                    if epoch % FLAGS.save_epochs == 0 and j == 0:
                        tf.logging.info(
                            'Saving model with global step %d (= %d epoch) to disk.'
                            % (_global_step, epoch))
                        sv.saver.save(sess,
                                      sv.save_path,
                                      global_step=sv.global_step)

        tf.logging.info('Training complete.')
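
SetOptimizer, used by Example #5 to build both optimizers, is not shown either. A hypothetical sketch that maps the --optimizer flag to a concrete TF1 optimizer; the accepted flag values are assumptions:

import tensorflow as tf

def SetOptimizer(learning_rate, optimizer_name):
    # Hypothetical: choose the optimizer by name (FLAGS is module-global).
    if optimizer_name == 'adam':
        return tf.train.AdamOptimizer(learning_rate,
                                      beta1=FLAGS.adam_beta1,
                                      beta2=FLAGS.adam_beta2,
                                      epsilon=FLAGS.adam_epsilon)
    elif optimizer_name == 'rmsprop':
        return tf.train.RMSPropOptimizer(learning_rate)
    elif optimizer_name == 'sgd':
        return tf.train.GradientDescentOptimizer(learning_rate)
    raise ValueError('Unknown optimizer: %s' % optimizer_name)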