# Example 1
def inference_generator_from_raw_tf(res):
    """Restore a trained generator checkpoint for resolution `res` and
    write a handful of sample images to ./assets.

    :param res: target output resolution; must be one of the
        progressive-growing resolutions (4, 8, ..., 1024)
    """
    all_resolutions = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
    all_featuremaps = [512, 512, 512, 512, 256, 128, 64, 32, 16]
    # keep only the stages up to (and including) the requested resolution
    n_stages = all_resolutions.index(res) + 1
    inference_resolutions = all_resolutions[:n_stages]
    inference_featuremaps = all_featuremaps[:n_stages]

    # prepare variables & construct generator graph
    image_out_dir = './assets'
    is_training = False
    z_dim = 512
    g_params = {
        'w_dim': 512,
        'n_mapping': 8,
        'resolutions': inference_resolutions,
        'featuremaps': inference_featuremaps,
        'truncation_psi': 0.7,
        'truncation_cutoff': 8,
    }
    z = tf.placeholder(tf.float32, shape=[None, z_dim], name='z')
    alpha = tf.constant(0.0, dtype=tf.float32, shape=[])
    fake_images = generator(z, alpha, g_params, is_training)

    # restore every global variable from the checkpoint
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    pprint.pprint(var_list)

    # locate the latest checkpoint trained at this resolution
    model_dir = '/mnt/vision-nas/moono/trained_models/stylegan-reproduced/{:d}x{:d}'.format(
        res, res)
    model_ckpt = tf.train.latest_checkpoint(os.path.join(model_dir))
    saver = tf.train.Saver(var_list=var_list)

    # fixed seed so the generated samples are reproducible
    n_output_samples = 4
    rnd = np.random.RandomState(5)
    z_input_np = rnd.randn(n_output_samples, z_dim)

    # run the generator and dump each sample as a PNG
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, model_ckpt)

        output_batch = sess.run(fake_images, feed_dict={z: z_input_np})

        for sample_idx, sample in enumerate(output_batch):
            image = post_process_generator_output(sample)
            # OpenCV expects BGR channel order on disk writes
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            out_fn = os.path.join(
                image_out_dir,
                'inference-{:d}-{:d}x{:d}.png'.format(sample_idx, res, res))
            cv2.imwrite(out_fn, image)
    return
def test_generator():
    """Build the full 1024x1024 generator, restore the official pretrained
    weights, and write one generated image to ./assets for visual comparison.
    """
    # prepare variables & construct generator graph
    image_out_dir = './assets'
    is_training = False
    z_dim = 512
    g_params = {
        'w_dim': 512,
        'n_mapping': 8,
        'resolutions': [4, 8, 16, 32, 64, 128, 256, 512, 1024],
        'featuremaps': [512, 512, 512, 512, 256, 128, 64, 32, 16],
        'truncation_psi': 0.7,
        'truncation_cutoff': 8,
    }
    z = tf.placeholder(tf.float32, shape=[None, z_dim], name='z')
    alpha = tf.constant(0.0, dtype=tf.float32, shape=[], name='alpha')
    fake_images = generator(z, alpha, g_params, is_training)

    # map our variable names onto the official checkpoint's names
    var_mapping = official_code_variables_to_restore()
    pprint.pprint(var_mapping)

    # restore tools
    model_dir = './official-pretrained'
    ckpt_name = 'model.ckpt'
    model_ckpt = os.path.join(model_dir, ckpt_name)
    saver = tf.train.Saver(var_list=var_mapping)

    # same seed as the official demo so outputs are comparable
    rnd = np.random.RandomState(5)
    z_input_np = rnd.randn(1, z_dim)

    # generate one image with the official weights
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, model_ckpt)

        generated = sess.run(fake_images, feed_dict={z: z_input_np})
        print(generated.shape)

        generated = post_process_generator_output(generated)
        # OpenCV expects BGR channel order on disk writes
        generated = cv2.cvtColor(generated, cv2.COLOR_RGB2BGR)

        out_fn = os.path.join(image_out_dir, 'from-official-weights.png')
        cv2.imwrite(out_fn, generated)
    return
# Example 3
def model_fn(features, labels, mode, params):
    """tf.estimator model_fn for StyleGAN training, evaluation and prediction.

    :param features: input dict — `real_images` in TRAIN mode, `z` in PREDICT mode
    :param labels: unused (required by the Estimator signature)
    :param mode: one of tf.estimator.ModeKeys.{TRAIN, EVAL, PREDICT}
    :param params: hyper-parameter dict; see the unpacking below for keys
    :return: tf.estimator.EstimatorSpec for the given mode
    """
    # parse params
    w_dim = params['w_dim']
    n_mapping = params['n_mapping']
    resolutions = params['resolutions']
    featuremaps = params['featuremaps']
    style_mixing_prob = params['style_mixing_prob']
    truncation_psi = params['truncation_psi']
    truncation_cutoff = params['truncation_cutoff']
    do_train_trans = params['do_train_trans']
    train_trans_images_per_res = params['train_trans_images_per_res']
    batch_size = params['batch_size']

    # additional params
    train_res = resolutions[-1]  # current (largest) training resolution
    w_ema_decay = params['w_ema_decay']
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    global_step = tf.train.get_or_create_global_step()

    # set generator & discriminator parameters
    g_params = {
        'w_dim': w_dim,
        'n_mapping': n_mapping,
        'resolutions': resolutions,
        'featuremaps': featuremaps,
        'w_ema_decay': w_ema_decay,
        'style_mixing_prob': style_mixing_prob,
        'truncation_psi': truncation_psi,
        'truncation_cutoff': truncation_cutoff,
    }
    d_params = {
        'resolutions': resolutions,
        'featuremaps': featuremaps,
    }

    # additional variables (reuse zero constants)
    zero_constant = tf.constant(0.0, dtype=tf.float32, shape=[])

    # additional variables (for training only)
    train_trans_images_per_res_tensor = tf.constant(
        train_trans_images_per_res,
        dtype=tf.float32,
        shape=[],
        name='train_trans_images_per_res')

    # smooth-transition variable: starts at 1.0 when fading in a new
    # resolution, 0.0 when training a stabilized network
    alpha = tf.get_variable(
        'alpha',
        shape=[],
        dtype=tf.float32,
        initializer=tf.initializers.ones()
        if do_train_trans else tf.initializers.zeros(),
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)

    # determine smooth transition state and compute alpha value
    alpha_const = smooth_transition_state(batch_size, global_step,
                                          train_trans_images_per_res_tensor,
                                          zero_constant)
    alpha_assign_op = tf.assign(alpha, alpha_const)

    # ==================================================================================================================
    # TRAINING
    # ==================================================================================================================
    if mode == tf.estimator.ModeKeys.TRAIN:
        # get training specific parameters
        z_dim = params['z_dim']
        g_learning_rate = params['g_learning_rate']
        d_learning_rate = params['d_learning_rate']

        # get inputs: latent z, real image input
        z = tf.random_normal(shape=[batch_size, z_dim], dtype=tf.float32)
        real_images = features['real_images']

        # get network outputs; depend on the alpha update so every step
        # uses the freshly computed transition value
        with tf.control_dependencies([alpha_assign_op]):
            # preprocess input images (NCHW layout)
            real_images.set_shape([None, 3, train_res, train_res])
            real_images = preprocess_fit_train_image(real_images,
                                                     train_res,
                                                     alpha=alpha)

            # create generator output
            fake_images = generator(z, alpha, g_params, is_training=True)

            # get discriminator outputs
            fake_scores = discriminator(fake_images, alpha, d_params)
            real_scores = discriminator(real_images, alpha, d_params)

        # prepare appropriate training vars
        d_vars, g_vars = filter_trainable_variables(train_res)

        # compute loss
        d_loss, g_loss, d_loss_gan, r1_penalty = compute_loss(
            real_images, real_scores, fake_scores)

        # combine loss for tf.estimator architecture
        loss = d_loss + g_loss

        # prepare optimizer & training ops
        # BUGFIX: the learning rates were previously swapped — the
        # discriminator optimizer was built with g_learning_rate and vice versa
        d_optimizer = tf.train.AdamOptimizer(d_learning_rate,
                                             beta1=0.0,
                                             beta2=0.99,
                                             epsilon=1e-8)
        g_optimizer = tf.train.AdamOptimizer(g_learning_rate,
                                             beta1=0.0,
                                             beta2=0.99,
                                             epsilon=1e-8)
        d_train_opt = d_optimizer.minimize(d_loss, var_list=d_vars)
        # only the generator step advances global_step, so it counts once
        # per combined train step
        g_train_opt = g_optimizer.minimize(g_loss,
                                           var_list=g_vars,
                                           global_step=global_step)
        train_op = tf.group(d_train_opt, g_train_opt)

        # add summaries (eval-mode generator output uses alpha = 0)
        fake_images_eval = generator(z,
                                     zero_constant,
                                     g_params,
                                     is_training=False)
        summary_real_images = convert_to_rgb_images(real_images)
        summary_fake_images = convert_to_rgb_images(fake_images)
        summary_fake_images_eval = convert_to_rgb_images(fake_images_eval)
        tf.summary.scalar('alpha', alpha)
        tf.summary.scalar('d_loss_gan', d_loss_gan)
        tf.summary.scalar('r1_penalty', r1_penalty)
        tf.summary.scalar('d_loss', d_loss)
        tf.summary.scalar('g_loss', g_loss)
        tf.summary.image('real_images', summary_real_images[:5], max_outputs=5)
        tf.summary.image('fake_images', summary_fake_images[:5], max_outputs=5)
        tf.summary.image('fake_images_eval',
                         summary_fake_images_eval[:5],
                         max_outputs=5)
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops={},
                                          predictions={})

    # ==================================================================================================================
    # EVALUATION
    # ==================================================================================================================
    if mode == tf.estimator.ModeKeys.EVAL:
        # tf.summary.image not working on eval mode?
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=zero_constant,
                                          eval_metric_ops={})

    # ==================================================================================================================
    # PREDICTION
    # ==================================================================================================================
    if mode == tf.estimator.ModeKeys.PREDICT:
        # get input latent z
        z = features['z']

        # create generator output for prediction (alpha = 0: fully stabilized)
        fake_images_eval = generator(z,
                                     zero_constant,
                                     g_params,
                                     is_training=False)

        predictions = {'fake_images': fake_images_eval}
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions)