def __init__(self, vgg_path, style_image, content_shape, content_weight,
                 style_weight, tv_weight, batch_size, device, log_f):
        with tf.device(device):
            self.log_file = log_f
            vgg = vgg_network.VGG(vgg_path)
            self.style_image = style_image
            self.batch_size = batch_size
            self.batch_shape = (batch_size, ) + content_shape

            self.input_batch = tf.placeholder(tf.float32,
                                              shape=self.batch_shape,
                                              name="input_batch")

            self.stylized_image = transform.net(self.input_batch, _vgg=vgg)

            loss_calculator = LossCalculator(vgg, self.stylized_image)

            self.content_loss = loss_calculator.content_loss(
                self.input_batch, self.CONTENT_LAYER,
                content_weight) / self.batch_size

            self.style_loss = loss_calculator.style_loss(
                self.style_image, self.STYLE_LAYERS,
                style_weight) / self.batch_size

            self.total_variation_loss = loss_calculator.tv_loss(
                self.stylized_image, tv_weight) / self.batch_size

            self.loss = self.content_loss + self.style_loss + self.total_variation_loss
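
# A minimal training-loop sketch (not part of the original source) showing how
# the graph built above would typically be driven: `model`, `batches`, and
# `learning_rate` are hypothetical names standing in for an instance of the
# surrounding class and the caller's data pipeline.
def _example_train_loop(model, batches, learning_rate=1e-3):
    # Adam minimizes the combined content/style/TV loss; the trainable
    # variables belong to the transform network created in transform.net.
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(model.loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for batch in batches:  # each batch has shape model.batch_shape
            _, loss_value = sess.run([train_step, model.loss],
                                     feed_dict={model.input_batch: batch})
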
def __init__(self, vgg_path, content, style, content_weight,
             style_weight, tv_weight, initial, device):
        with tf.device(device):
            self.vgg = vgg_network.VGG(vgg_path)
            self.content = content
            self.style = style
            self.image = self._get_initial_image_or_random(initial)

            loss_calculator = LossCalculator(self.vgg, self.image)

            self.content_loss = loss_calculator.content_loss(content,
                                                             self.CONTENT_LAYER,
                                                             content_weight)

            self.style_loss = loss_calculator.style_loss(style,
                                                         self.STYLE_LAYERS,
                                                         style_weight)

            self.total_variation_loss = loss_calculator.tv_loss(self.image,
                                                                self.content.shape,
                                                                tv_weight)

            self.loss = self.content_loss + self.style_loss + self.total_variation_loss
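
# A minimal optimization-loop sketch (not part of the original source) for this
# Gatys-style variant: here the pixels of model.image are the trainable
# variable, so the optimizer adjusts the image itself rather than a network.
# Assumes _get_initial_image_or_random returned a tf.Variable; `model` and
# `iterations` are hypothetical names.
def _example_optimize_image(model, iterations=1000, learning_rate=1e1):
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(
        model.loss, var_list=[model.image])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(iterations):
            sess.run(train_step)
        return sess.run(model.image)  # the stylized result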
def main():
    import vgg_network

    parser = build_parser()
    options = parser.parse_args()
    check_opts(options)

    vgg = vgg_network.VGG(options.vgg_path)

    network = options.network_path
    if not os.path.isdir(network):
        parser.error("Network %s does not exist." % network)

    content_image = utils.load_image(options.content)
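    # Trim height and width down to the nearest multiples of 4; the % 4
    # cropping suggests the transform network downsamples and upsamples by a
    # factor of 4, so input dimensions must be divisible by 4 (an inference,
    # not stated in the original).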
    reshaped_content_height = (content_image.shape[0] -
                               content_image.shape[0] % 4)
    reshaped_content_width = (content_image.shape[1] -
                              content_image.shape[1] % 4)
    reshaped_content_image = content_image[:reshaped_content_height,
                                           :reshaped_content_width, :]
    reshaped_content_image = reshaped_content_image.reshape(
        (1, ) + reshaped_content_image.shape)

    prediction = ffwd(reshaped_content_image, network, vgg)
    utils.save_image(prediction, options.output_path)
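
# `ffwd` is called above but not defined in this snippet. A sketch of what such
# a feed-forward pass typically looks like (an assumption, not the original
# implementation): rebuild the transform network, restore the trained weights
# from the checkpoint directory, and run the content image through once.
def ffwd_sketch(content_image, network_dir, vgg):
    with tf.Graph().as_default(), tf.Session() as sess:
        placeholder = tf.placeholder(tf.float32, shape=content_image.shape,
                                     name="input")
        stylized = transform.net(placeholder, _vgg=vgg)
        saver = tf.train.Saver()
        checkpoint = tf.train.get_checkpoint_state(network_dir)
        saver.restore(sess, checkpoint.model_checkpoint_path)
        return sess.run(stylized, feed_dict={placeholder: content_image})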
def main(_):
    ########################### Set Parameters ###########################
    dataset = 'celebA'
    train_size = np.inf
    num_epochs = 25
    sample_step = 500  # how often we plot and see the reconstruction results
    sample_size = 64  # the size of sample images
    save_step = 500  # how often we save the parameters of the network

    batch_size = 64
    original_size = 108  # original image size
    is_crop = True
    input_size = 64  # image size after crop
    c_dim = 3
    lam = 0.5  # weight balancing the pixel loss against the perceptual loss
    lr = 0.0005  # learning rate

    dcgan_param_path = './dcgan/checkpoint/celebA_64_64/'
    checkpoint_dir = './checkpoint/'
    sample_dir = 'samples'
    vgg_path = './vgg/imagenet-vgg-verydeep-19.mat'
    ############################ Define Model ############################
    print('Building Model...')
    # feed input_img into the encoder to get latent_var, then feed latent_var into the decoder to get output_img
    input_img = tf.placeholder(tf.float32,
                               [batch_size, input_size, input_size, c_dim],
                               name='input_img')
    print('Input shape: ', input_img.get_shape())
    encoder_net, _ = vanilla_encoder(input_img, z_dim=100)
    print('Latent shape: ', encoder_net.outputs.get_shape())
    decoder_net, _ = dcgan_decoder(encoder_net.outputs,
                                   image_size=input_size,
                                   c_dim=c_dim,
                                   batch_size=batch_size)
    print('Output shape: ', decoder_net.outputs.get_shape())
    print('Model successfully built!')
    #################### Define Loss and Training Ops ####################
    # pixel loss: mse
    #loss_pixel = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(decoder_net.outputs, input_img), [1, 2, 3]))
    loss_pixel = tf.nn.l2_loss(input_img - decoder_net.outputs) / (
        batch_size * input_size * input_size * c_dim)
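    # Note: tf.nn.l2_loss computes sum(x ** 2) / 2, so loss_pixel is half the
    # per-element mean squared error; the factor of 2 only rescales the pixel
    # term relative to the perceptual term below.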

    # perceptual loss: originally computed from the fourth convolutional layer of an
    # ImageNet-pretrained AlexNet (commented out below); replaced here by a VGG-based content loss
    # concat_img = tf.concat([input_img, decoder_net.outputs], axis = 0) # concat images along the first dimension
    # alexnet = alexnet_model(concat_img)
    # concat_features = alexnet.conv4
    # print('Shape of concat features: ', concat_features.get_shape())
    # # split the features into two partitions evenly
    # alexnet_features_input_img, alexnet_features_output_img = tf.split(concat_features, num_or_size_splits=2, axis=0)
    # loss_perceptual = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(alexnet_features_output_img, alexnet_features_input_img), [1, 2, 3]))
    vgg = vgg_network.VGG(vgg_path)
    loss_calculator = LossCalculator(vgg, decoder_net.outputs)
    loss_perceptual = loss_calculator.content_loss(
        input_img, content_layer='relu4_3', content_weight=1) / batch_size
    #loss_perceptual = tf.constant(0.0)
    # weighted sum of the two losses
    loss = lam * loss_pixel + (1 - lam) * loss_perceptual

    train_param = encoder_net.all_params + decoder_net.all_params  # update the parameters of both the encoder and the decoder
    train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        loss, var_list=train_param)

    ######################### Initialization Step ########################
    sess = tf.InteractiveSession()
    tl.layers.initialize_global_variables(sess)

    # load trained parameters of dcgan_decoder
    print('Loading trained parameters of decoder network...')
    decoder_params = tl.files.load_npz(name=dcgan_param_path + 'net_g.npz')
    tl.files.assign_params(sess, decoder_params, decoder_net)
    print("Having loaded trained parameters of decoder network!")
    decoder_net.print_params()

    # load trained parameters of alexnet for extracting features
    #alexnet.load_initial_weights(sess)

    # Set the path to save parameters
    model_dir = "%s_%s_%s" % (dataset, batch_size, input_size)
    save_dir = os.path.join(checkpoint_dir, model_dir)
    tl.files.exists_or_mkdir(save_dir)
    tl.files.exists_or_mkdir(sample_dir)

    enc_compression_path = os.path.join(save_dir, 'enc_compression.npz')
    dec_compression_path = os.path.join(save_dir, 'dec_compression.npz')

    # get list of all training images' paths
    data_files = glob(os.path.join(
        "./data", dataset, "*.jpg"))  # returns a list of paths for images

    ############################## Train Model ###############################
    iter_counter = 0
    for epoch in range(num_epochs):
        ## shuffle data list
        shuffle(data_files)

        ## update sample files based on shuffled data
        sample_files = data_files[0:sample_size]
        sample = [
            get_image(sample_file,
                      original_size,
                      is_crop=is_crop,
                      resize_w=input_size,
                      is_grayscale=0) for sample_file in sample_files
        ]
        sample_images = np.array(sample).astype(np.float32)
        print("[*] Sample images updated!")

        ## compute the number of batches per epoch
        batch_idxs = min(len(data_files), train_size) // batch_size

        for idx in range(0, batch_idxs):
            # list containing the paths of one batch of images
            batch_files = data_files[idx * batch_size:(idx + 1) * batch_size]

            ## get real images
            batch = [
                get_image(batch_file,
                          original_size,
                          is_crop=is_crop,
                          resize_w=input_size,
                          is_grayscale=0) for batch_file in batch_files
            ]
            batch_images = np.array(batch).astype(np.float32)

            start_time = time.time()
            # update the autoencoder (encoder + decoder) on one batch
            pix_loss, percep_loss, tot_loss, _ = sess.run(
                [loss_pixel, loss_perceptual, loss, train_op],
                feed_dict={input_img: batch_images})

            print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, pix_loss: %.8f, percep_loss: %.8f, tot_loss: %.8f" \
                    % (epoch, num_epochs, idx, batch_idxs, time.time() - start_time, pix_loss, percep_loss, tot_loss))

            iter_counter += 1
            # Save sample images and their reconstructions
            if np.mod(iter_counter, sample_step) == 0:
                # generate and visualize generated images
                recon_img = sess.run(decoder_net.outputs,
                                     feed_dict={input_img: sample_images})
                print('Shape of input sample is: ', sample_images.shape)
                print('Shape of recon sample is: ', recon_img.shape)
                save_sample = np.concatenate((sample_images, recon_img),
                                             axis=0)
                print('Shape of save_sample is ', save_sample.shape)
                tl.visualize.save_images(
                    save_sample, [16, 8],
                    './{}/train_{:02d}_{:04d}.png'.format(
                        sample_dir, epoch, idx))
                #print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))

            # Save network parameters
            if np.mod(iter_counter, save_step) == 0:
                # save current network parameters
                print("[*] Saving checkpoints...")
                tl.files.save_npz(encoder_net.all_params,
                                  name=enc_compression_path,
                                  sess=sess)
                tl.files.save_npz(decoder_net.all_params,
                                  name=dec_compression_path,
                                  sess=sess)
                print("[*] Saving checkpoints SUCCESS!")