# Assumes TensorFlow 1.x with tf.contrib.slim available as 'slim'; 'vgg', 'load_data', and
# 'model' are project-local modules (this vgg.vgg_16 returns four feature maps plus an
# exclude list, not the stock slim signature).
def build_model(self):
        data_path = self.args.train_data_path

        imgs = load_data.get_loader(data_path, self.batch_size, self.img_size)

        style_imgs = load_style_img(self.args.style_data_path)
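        # The content batch and the style target(s) feed the VGG-16 feature losses built below.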

        with slim.arg_scope(model.arg_scope()):
            gen_img, variables = model.gen_net(imgs,
                                               reuse=False,
                                               name='transform')

            with slim.arg_scope(vgg.vgg_arg_scope()):
                gen_img_processed = [
                    load_data.img_process(image, True) for image in tf.unstack(
                        gen_img, axis=0, num=self.batch_size)
                ]

                f1, f2, f3, f4, exclude = vgg.vgg_16(
                    tf.concat([gen_img_processed, imgs, style_imgs], axis=0))

                gen_f, img_f, _ = tf.split(f3, 3, 0)
                content_loss = tf.nn.l2_loss(gen_f - img_f) / tf.to_float(
                    tf.size(gen_f))
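                # tf.nn.l2_loss returns sum(x**2) / 2, so this is half the mean squared feature distance.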

                style_loss = model.styleloss(f1, f2, f3, f4)

                # load vgg model
                vgg_model_path = self.args.vgg_model
                vgg_vars = slim.get_variables_to_restore(include=['vgg_16'],
                                                         exclude=exclude)
                # vgg_init_var = slim.get_variables_to_restore(include=['vgg_16/fc6'])
                init_fn = slim.assign_from_checkpoint_fn(
                    vgg_model_path, vgg_vars)
                init_fn(self.sess)
                # tf.initialize_variables(var_list=vgg_init_var)
                print('vgg_16 weights loaded')

            self.gen_img = gen_img

            self.global_step = tf.Variable(0,
                                           name="global_step",
                                           trainable=False)

            self.content_loss = content_loss
            self.style_loss = style_loss * self.args.style_w
            self.loss = self.content_loss + self.style_loss
            self.opt = tf.train.AdamOptimizer(0.0001).minimize(
                self.loss, global_step=self.global_step, var_list=variables)

        all_var = tf.global_variables()
        # init_var = [v for v in all_var if 'beta' in v.name or 'global_step' in v.name or 'Adam' in v.name]
        init_var = [v for v in all_var if 'vgg_16' not in v.name]
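        # Initialize everything except the VGG variables, which were already restored from their checkpoint above.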
        init = tf.variables_initializer(var_list=init_var)
        self.sess.run(init)

        self.save = tf.train.Saver(var_list=variables)
Example #2
    def build_network(self, inputs, fg_inputs, targets, trainable=True):
        def discrim_conv(batch_input, out_channels, stride):
            padded_input = tf.pad(batch_input,
                                  [[0, 0], [1, 1], [1, 1], [0, 0]],
                                  mode="CONSTANT")
            return tf.layers.conv2d(
                padded_input,
                out_channels,
                kernel_size=4,
                strides=(stride, stride),
                padding="valid",
                kernel_initializer=tf.random_normal_initializer(0, 0.02))

        def gen_conv(batch_input, out_channels):
            # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
            initializer = tf.random_normal_initializer(0, 0.02)
            if self.separable_conv:
                return tf.layers.separable_conv2d(
                    batch_input,
                    out_channels,
                    kernel_size=4,
                    strides=(2, 2),
                    padding="same",
                    depthwise_initializer=initializer,
                    pointwise_initializer=initializer)
            else:
                return tf.layers.conv2d(batch_input,
                                        out_channels,
                                        kernel_size=4,
                                        strides=(2, 2),
                                        padding="same",
                                        kernel_initializer=initializer)

        def gen_deconv(batch_input, out_channels):
            # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
            initializer = tf.random_normal_initializer(0, 0.02)
            if self.separable_conv:
                _b, h, w, _c = batch_input.shape
                resized_input = tf.image.resize_images(
                    batch_input, [h * 2, w * 2],
                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                return tf.layers.separable_conv2d(
                    resized_input,
                    out_channels,
                    kernel_size=4,
                    strides=(1, 1),
                    padding="same",
                    depthwise_initializer=initializer,
                    pointwise_initializer=initializer)
            else:
                return tf.layers.conv2d_transpose(
                    batch_input,
                    out_channels,
                    kernel_size=4,
                    strides=(2, 2),
                    padding="same",
                    kernel_initializer=initializer)

        def lrelu(x, a):
            with tf.name_scope("lrelu"):
                # adding these together creates the leak part and linear part
                # then cancels them out by subtracting/adding an absolute value term
                # leak: a*x/2 - a*abs(x)/2
                # linear: x/2 + abs(x)/2

                # this block looks like it has 2 inputs on the graph unless we do this
                x = tf.identity(x)
                return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)

        def batchnorm(inputs):
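            # Per-batch statistics are always used (training=True); no inference-mode moving averages are applied.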
            return tf.layers.batch_normalization(
                inputs,
                axis=3,
                epsilon=1e-5,
                momentum=0.1,
                training=True,
                gamma_initializer=tf.random_normal_initializer(1.0, 0.02))

        def create_discriminator(discrim_inputs, discrim_targets):
            n_layers = 3
            layers = []

            # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
            input = tf.concat([discrim_inputs, discrim_targets], axis=3)

            # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
            with tf.variable_scope("layer_1"):
                convolved = discrim_conv(input, self.ndf, stride=2)
                rectified = lrelu(convolved, 0.2)
                layers.append(rectified)

            # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
            # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
            # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
            for i in range(n_layers):
                with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                    out_channels = self.ndf * min(2**(i + 1), 8)
                    stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
                    convolved = discrim_conv(layers[-1],
                                             out_channels,
                                             stride=stride)
                    normalized = batchnorm(convolved)
                    rectified = lrelu(normalized, 0.2)
                    layers.append(rectified)

            # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                convolved = discrim_conv(rectified, out_channels=1, stride=1)
                output = tf.sigmoid(convolved)
                layers.append(output)

            return layers[-1]

        def create_target_discriminator(discrim_inputs):
            n_layers = 3
            layers = []

            # layer_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ndf]
            with tf.variable_scope("layer_1"):
                convolved = discrim_conv(discrim_inputs, self.ndf, stride=2)
                rectified = lrelu(convolved, 0.2)
                layers.append(rectified)

            # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
            # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
            # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
            for i in range(n_layers):
                with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                    out_channels = self.ndf * min(2**(i + 1), 8)
                    stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
                    convolved = discrim_conv(layers[-1],
                                             out_channels,
                                             stride=stride)
                    normalized = batchnorm(convolved)
                    rectified = lrelu(normalized, 0.2)
                    layers.append(rectified)

            # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                convolved = discrim_conv(rectified, out_channels=1, stride=1)
                output = tf.sigmoid(convolved)
                layers.append(output)

            return layers[-1]

        def create_generator(generator_inputs, generator_fg_inputs,
                             generator_outputs_channels):
            layers = []

            # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
            with tf.variable_scope("encoder_1"):
                output = gen_conv(generator_inputs, self.ngf)
                layers.append(output)

            layer_specs = [
                self.ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
                self.ngf * 2,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 2]
                self.ngf * 4,  # encoder_4: [batch, 32, 32, ngf * 2] => [batch, 16, 16, ngf * 4]
            ]

            for out_channels in layer_specs:
                with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
                    rectified = lrelu(layers[-1], 0.2)
                    # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
                    convolved = gen_conv(rectified, out_channels)
                    output = batchnorm(convolved)
                    layers.append(output)

            fg_layers = []
            # encoder_fg_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
            with tf.variable_scope("encoder_fg_1"):
                output = gen_conv(generator_fg_inputs, self.ngf)
                fg_layers.append(output)

            layer_specs = [
                self.ngf * 2,  # encoder_fg_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
                self.ngf * 2,  # encoder_fg_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 2]
                self.ngf * 4,  # encoder_fg_4: [batch, 32, 32, ngf * 2] => [batch, 16, 16, ngf * 4]
            ]

            for out_channels in layer_specs:
                with tf.variable_scope("encoder_fg_%d" % (len(fg_layers) + 1)):
                    rectified = lrelu(fg_layers[-1], 0.2)
                    # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
                    convolved = gen_conv(rectified, out_channels)
                    output = batchnorm(convolved)
                    fg_layers.append(output)

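            # Fuse the main-input and foreground-input encodings along the channel axis before the shared trunk.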
            merged_layers = [tf.concat([layers[-1], fg_layers[-1]], axis=3)]

            layer_specs = [
                self.ngf * 4,  # merged_encoder_2: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 4]
                self.ngf * 8,  # merged_encoder_3: [batch, 8, 8, ngf * 4] => [batch, 4, 4, ngf * 8]
                self.ngf * 8,  # merged_encoder_4: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
                self.ngf * 8,  # merged_encoder_5: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
            ]

            for out_channels in layer_specs:
                with tf.variable_scope("merged_encoder_%d" %
                                       (len(merged_layers) + 1)):
                    rectified = lrelu(merged_layers[-1], 0.2)
                    # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
                    convolved = gen_conv(rectified, out_channels)
                    output = batchnorm(convolved)
                    merged_layers.append(output)

            layer_specs = [
                self.ngf * 8,  # merged_decoder_5: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8]
                self.ngf * 8,  # merged_decoder_4: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8]
                self.ngf * 4,  # merged_decoder_3: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 4]
                self.ngf * 4,  # merged_decoder_2: [batch, 8, 8, ngf * 4 * 2] => [batch, 16, 16, ngf * 4]
            ]

            num_encoder_layers = len(merged_layers)
            for decoder_layer, out_channels in enumerate(layer_specs):
                skip_layer = num_encoder_layers - decoder_layer - 1
                with tf.variable_scope("merged_decoder_%d" % (skip_layer + 1)):
                    if decoder_layer == 0:
                        # first decoder layer doesn't have skip connections
                        # since it is directly connected to the skip_layer
                        input = merged_layers[-1]
                    else:
                        input = tf.concat(
                            [merged_layers[-1], merged_layers[skip_layer]],
                            axis=3)

                    rectified = tf.nn.relu(input)
                    # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
                    output = gen_deconv(rectified, out_channels)
                    output = batchnorm(output)

                    merged_layers.append(output)

            layer_specs = [
                self.ngf * 2,  # merged2_decoder_4: [batch, 16, 16, ngf * 4 * 2] => [batch, 32, 32, ngf * 2]
                self.ngf * 2,  # merged2_decoder_3: [batch, 32, 32, ngf * 2 * 2] => [batch, 64, 64, ngf * 2]
                self.ngf,      # merged2_decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf]
            ]

            num_encoder_layers = len(layers)
            for decoder_layer, out_channels in enumerate(layer_specs):
                skip_layer = num_encoder_layers - decoder_layer - 1
                with tf.variable_scope("merged2_decoder_%d" %
                                       (skip_layer + 1)):
                    input = tf.concat([merged_layers[-1], layers[skip_layer]],
                                      axis=3)

                    rectified = tf.nn.relu(input)
                    # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
                    output = gen_deconv(rectified, out_channels)
                    output = batchnorm(output)

                    merged_layers.append(output)

            # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
            with tf.variable_scope("decoder_1"):
                input = tf.concat([merged_layers[-1], layers[0]], axis=3)
                rectified = tf.nn.relu(input)
                output = gen_deconv(rectified, generator_outputs_channels)
                output = tf.tanh(output)
                layers.append(output)

            return layers[-1]

        nodes = {}
        with tf.variable_scope("generator"):
            output = create_generator(inputs,
                                      fg_inputs[..., :3],
                                      generator_outputs_channels=4)
            rgb = output[:, :, :, :3]
            alpha = (output[:, :, :, 3:] + 1) / 2
            alpha = tf.tile(alpha, [1, 1, 1, 3])
            output = rgb * alpha + targets * (1 - alpha)
            output_fg = rgb * alpha + alpha - 1
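            # 'output' alpha-composites the generated RGB over 'targets';
            # 'output_fg' composites it over a constant -1 background (black in the [-1, 1] range).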

            nodes.update({'Outputs': output})
            nodes.update({'Alphas': alpha})
            nodes.update({'Outputs_FG': output_fg})

        if trainable:
            # create two copies of discriminator, one for real pairs and one for fake pairs
            # they share the same underlying variables
            with tf.name_scope("real_discriminator"):
                with tf.variable_scope("discriminator"):
                    predict_real = create_discriminator(
                        inputs[..., 3:], fg_inputs[..., 3:])
                with tf.variable_scope("discriminator", reuse=True):
                    predict_real2 = create_discriminator(
                        inputs[..., :3], fg_inputs[..., :3])
                    predict_real = (predict_real + predict_real2) / 2
                    nodes.update({'Predict_real': predict_real})

            with tf.name_scope("fake_discriminator"):
                with tf.variable_scope("discriminator", reuse=True):
                    predict_fake = create_discriminator(
                        inputs[..., 3:], output_fg)
                    nodes.update({'Predict_fake': predict_fake})

            # with tf.name_scope("real_target_discriminator"):
            #   with tf.variable_scope("target_discriminator"):
            #     predict_real = create_target_discriminator(fg_inputs[:, 384:, :, 3:])
            #     nodes.update({'Predict_real_target': predict_real})

            # with tf.name_scope("fake_target_discriminator"):
            #   with tf.variable_scope("target_discriminator", reuse=True):
            #     predict_fake = create_target_discriminator(output_fg[:, 384:, :, :])
            #     nodes.update({'Predict_fake_target': predict_fake})

            with tf.name_scope("vgg_perceptual"):
                with slim.arg_scope(vgg.vgg_arg_scope()):

                    f1, f2, f3, f4, exclude = vgg.vgg_16(
                        tf.concat([fg_inputs[..., 3:], output_fg], axis=0))
                    gen_f, img_f = tf.split(f3, 2, 0)
                    content_loss = tf.nn.l2_loss(gen_f - img_f) / tf.to_float(
                        tf.size(gen_f))
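                    # Perceptual loss: size-normalized L2 distance between VGG features of the real and generated foregrounds.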

                    vgg_vars = slim.get_variables_to_restore(
                        include=['vgg_16'], exclude=exclude)
                    init_fn = slim.assign_from_checkpoint_fn(
                        self.vgg_model_path, vgg_vars)
                    init_fn(self.sess)
                    nodes.update({'Perceptual_loss': content_loss})

        return nodes