Example #1
    def __init__(self,
                 relu_targets=None,
                 vgg_path=None,
                 alpha=0.7,
                 beta=0.5,
                 ss_alpha=0.7,
                 use_adain=False):
        '''
            Args:
                relu_targets: List of relu target layers corresponding to decoder checkpoints
                vgg_path: Normalised VGG19 .t7 path
        '''
        super().__init__()
        self.vgg_model = vgg_from_t7(vgg_path, target_layer=relu_targets)
        self.relu_targets = relu_targets
        self.alpha = alpha
        self.beta = beta
        self.ss_alpha = ss_alpha
        self.use_adain = use_adain
        self.encoders = []
        self.decoders = []

        for relu_target in relu_targets:
            self.encoders.append(self.build_encoder(relu_target))
            self.decoders.append(self.build_decoder(relu_target))
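A hedged sketch for context: build_encoder and build_decoder are not shown in this snippet. Judging from how the later examples slice the shared VGG graph, build_encoder could simply return the Keras sub-model that ends at the target relu layer (the body below is an assumption, not the repo's actual code; Model is the Keras Model class used throughout these examples):

    def build_encoder(self, relu_target):
        # Assumed implementation: sub-model of the shared VGG that outputs the
        # activations at `relu_target`, mirroring the style_encoder_model
        # construction in the later examples.
        return Model(inputs=self.vgg_model.input,
                     outputs=self.vgg_model.get_layer(relu_target).output)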
Example #2
    def __init__(self, mode='train', relu_targets=['relu5_1','relu4_1','relu3_1','relu2_1','relu1_1'], vgg_path=None,  *args, **kwargs):
        '''
            Args:
                mode: 'train' or 'test'. If 'train' then training & summary ops will be added to the graph
                relu_targets: List of relu target layers corresponding to decoder checkpoints
                vgg_path: Normalised VGG19 .t7 path
        '''
        self.mode = mode

        self.style_input = tf.placeholder_with_default(tf.constant([[[[0.,0.,0.]]]]), shape=(None, None, None, 3), name='style_img')

        # Flag for applying WCT; should only be True in test mode. When False, the content encoding is passed through unchanged.
        self.apply_wct = tf.placeholder_with_default(tf.constant(False), shape=[])

        self.alpha = tf.placeholder_with_default(1., shape=[], name='alpha')

        self.encoder_decoders = []
        
        ### Build the graph ###
        
        # Load shared VGG model up to deepest target layer
        with tf.name_scope('vgg_encoder'):
            deepest_target = sorted(relu_targets)[-1]
            print('Loading VGG up to layer', deepest_target)
            self.vgg_model = vgg_from_t7(vgg_path, target_layer=deepest_target)
            self.vgg_model.summary()  # summary() already prints; wrapping it in print() would also emit "None"

        if self.mode == 'train':
            style_encodings = [None]  # Style encoding is not needed for train stage
        else:
            # Build model to extract intermediate relu layers for style img to be used in multi-level pipeline
            with tf.name_scope('style_encoder'):
                style_encoding_layers = [self.vgg_model.get_layer(relu).output for relu in relu_targets]
                style_encoder_model = Model(inputs=self.vgg_model.input, outputs=style_encoding_layers)
                style_encodings = style_encoder_model(self.style_input)

            if len(relu_targets) == 1:
                style_encodings = [style_encodings]

        # Build enc/decs for each target relu and hook the output of each decoder up to the next encoder's input
        for i, (relu, style_encoded) in enumerate(zip(relu_targets, style_encodings)):
            print('Building encoder/decoder for relu target',relu)
            
            if i == 0:
                # Input tensor will be a placeholder for the first encoder/decoder
                input_tensor = None
            else:
                # Input to intermediate levels is the output from previous decoder
                input_tensor = clip(self.encoder_decoders[-1].decoded)
            
            enc_dec = self.build_model(relu, input_tensor=input_tensor, style_encoded_tensor=style_encoded, **kwargs)
        
            self.encoder_decoders.append(enc_dec)

        # Hooks for placeholder input for first encoder and final output from last decoder
        self.content_input  = self.encoder_decoders[0].content_input
        self.decoded_output = self.encoder_decoders[-1].decoded
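The wct_tf op referenced throughout these examples is not shown. As a reference for what it computes, here is a minimal NumPy sketch of the whitening-coloring transform (the name wct_np, the eps regularization, and the full-rank eigendecomposition are simplifications/assumptions, not the repo's code):

import numpy as np

def wct_np(content, style, alpha=1.0, eps=1e-8):
    # Whitening-Coloring Transform on single (H, W, C) feature maps.
    Hc, Wc, C = content.shape
    fc = content.reshape(-1, C).T                      # (C, Hc*Wc)
    fs = style.reshape(-1, C).T                        # (C, Hs*Ws)
    mc, ms = fc.mean(1, keepdims=True), fs.mean(1, keepdims=True)
    fc_c, fs_c = fc - mc, fs - ms

    # Whiten the content features using the content covariance eigenbasis
    Dc, Ec = np.linalg.eigh(fc_c @ fc_c.T / (fc_c.shape[1] - 1) + eps * np.eye(C))
    fc_white = Ec @ np.diag(np.maximum(Dc, eps) ** -0.5) @ Ec.T @ fc_c

    # Color them with the style covariance eigenbasis and re-add the style mean
    Ds, Es = np.linalg.eigh(fs_c @ fs_c.T / (fs_c.shape[1] - 1) + eps * np.eye(C))
    fcs = Es @ np.diag(np.maximum(Ds, eps) ** 0.5) @ Es.T @ fc_white + ms

    # Blend the stylized features with the original content features
    blended = alpha * fcs + (1 - alpha) * fc
    return blended.T.reshape(Hc, Wc, C)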
Example #3
    def build_model(self, vgg_weights):
        self.content_imgs = tf.placeholder(shape=(None, None, None, 3),
                                           name='content_imgs',
                                           dtype=tf.float32)
        self.style_imgs = tf.placeholder(shape=(None, None, None, 3),
                                         name='style_imgs',
                                         dtype=tf.float32)
        self.alpha = tf.placeholder_with_default(1., shape=[], name='alpha')

        ### Load shared VGG model up to relu4_1
        with tf.name_scope('encoder'):
            self.vgg_model = vgg_from_t7(vgg_weights, target_layer='relu4_1')
        self.vgg_model.summary()  # summary() already prints; wrapping it in print() would also emit "None"

        ### Build encoders for content layer
        with tf.name_scope('content_layer_encoder'):
            # Build content layer encoding model
            content_layer = self.vgg_model.get_layer('relu4_1').output
            self.content_encoder_model = Model(inputs=self.vgg_model.input,
                                               outputs=content_layer)

            # Setup content layer encodings for content/style images
            self.content_encoded = self.content_encoder_model(
                self.content_imgs)
            self.style_encoded = self.content_encoder_model(self.style_imgs)

            # Apply affine Adaptive Instance Norm transform
            self.adain_encoded = adain(self.content_encoded,
                                       self.style_encoded, self.alpha)

        ### Build decoder
        with tf.name_scope('decoder'):
            n_channels = self.adain_encoded.get_shape()[-1].value
            self.decoder_model = self.build_decoder(input_shape=(None, None,
                                                                 n_channels))

            # Setup a placeholder that defaults to the adain tensor but can be substituted with a feed_dict. Needed for interpolation.
            self.adain_encoded_pl = tf.placeholder_with_default(
                self.adain_encoded, shape=self.adain_encoded.get_shape())

            # Stylized/decoded output from AdaIN transformed encoding
            self.decoded = self.decoder_model(
                Lambda(lambda x: x)(self.adain_encoded_pl))  # identity Lambda wraps the TF tensor so the Keras decoder accepts it

        # Content layer encoding of the stylized output
        self.decoded_encoded = self.content_encoder_model(self.decoded)
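The adain op used above is standard Adaptive Instance Normalization: normalize each channel of the content features and rescale it with the corresponding style channel statistics. A minimal NumPy sketch (adain_np and eps are illustrative names/values, not the repo's code):

import numpy as np

def adain_np(content, style, alpha=1.0, eps=1e-5):
    # Adaptive Instance Normalization on (H, W, C) feature maps.
    c_mean, c_std = content.mean((0, 1), keepdims=True), content.std((0, 1), keepdims=True) + eps
    s_mean, s_std = style.mean((0, 1), keepdims=True), style.std((0, 1), keepdims=True) + eps
    stylized = (content - c_mean) / c_std * s_std + s_mean
    # alpha blends between fully stylized (1.0) and unchanged content features (0.0)
    return alpha * stylized + (1 - alpha) * content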
Example #4
    def __init__(self, mode='train', relu_targets=['relu5_1','relu4_1','relu3_1','relu2_1','relu1_1'], vgg_path=None,  
                 *args, **kwargs):
        '''
            Args:
                mode: 'train' or 'test'. If 'train' then training & summary ops will be added to the graph
                relu_targets: List of relu target layers corresponding to decoder checkpoints
                vgg_path: Normalised VGG19 .t7 path
        '''
        self.mode = mode

        self.style_input = tf.placeholder_with_default(tf.constant([[[[0.,0.,0.]]]]), shape=(None, None, None, 3), name='style_img')

        self.alpha = tf.placeholder_with_default(1., shape=[], name='alpha')
        
        # Style swap settings
        self.swap5 = tf.placeholder_with_default(tf.constant(False), shape=[])
        self.ss_alpha = tf.placeholder_with_default(.7, shape=[], name='ss_alpha')

        # Flag to use AdaIN instead of WCT
        self.use_adain = tf.placeholder_with_default(tf.constant(False), shape=[])
        
        self.encoder_decoders = []
        
        ### Build the graph ###
        
        # Load shared VGG model up to deepest target layer
        with tf.name_scope('vgg_encoder'):
            deepest_target = sorted(relu_targets)[-1]
            print('Loading VGG up to layer', deepest_target)
            self.vgg_model = vgg_from_t7(vgg_path, target_layer=deepest_target)
            self.vgg_model.summary()  # summary() already prints; wrapping it in print() would also emit "None"

        if self.mode == 'train':
            style_encodings = [None]  # Style encoding is not needed for train stage
        else:
            # Build model to extract intermediate relu layers for style img to be used in multi-level pipeline
            with tf.name_scope('style_encoder'):
                style_encoding_layers = [self.vgg_model.get_layer(relu).output for relu in relu_targets]
                style_encoder_model = Model(inputs=self.vgg_model.input, outputs=style_encoding_layers)
                style_encodings = style_encoder_model(self.style_input)

            if len(relu_targets) == 1:
                style_encodings = [style_encodings]

        # Build enc/decs for each target relu and hook the output of each decoder up to the next encoder's input
        for i, (relu, style_encoded) in enumerate(zip(relu_targets, style_encodings)):
            print('Building encoder/decoder for relu target',relu)
            
            if i == 0:
                # Input tensor will be a placeholder for the first encoder/decoder
                input_tensor = None
            else:
                # Input to intermediate levels is the output from previous decoder
                input_tensor = clip(self.encoder_decoders[-1].decoded)
            
            enc_dec = self.build_model(relu, input_tensor=input_tensor, style_encoded_tensor=style_encoded, **kwargs)
        
            self.encoder_decoders.append(enc_dec)

        # Hooks for placeholder input for first encoder and final output from last decoder
        self.content_input  = self.encoder_decoders[0].content_input
        self.decoded_output = self.encoder_decoders[-1].decoded
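A hypothetical test-time usage sketch for the class this __init__ belongs to. The class name MultiLevelStyleTransfer, the vgg_normalised.t7 path, and the checkpoint-restore step are assumptions; the fed placeholders (content_input, style_input, alpha) and the fetched decoded_output come from the code above:

import numpy as np
import tensorflow as tf

model = MultiLevelStyleTransfer(mode='test', vgg_path='vgg_normalised.t7')  # hypothetical class name and path

content = np.random.rand(1, 256, 256, 3).astype(np.float32)  # stand-in content batch
style   = np.random.rand(1, 256, 256, 3).astype(np.float32)  # stand-in style batch

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... restore the per-relu decoder checkpoints here ...
    stylized = sess.run(model.decoded_output,
                        feed_dict={model.content_input: content,
                                   model.style_input: style,
                                   model.alpha: 0.8})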
Example #5
    def __init__(self,
                 mode='train',
                 relu_targets=[
                     'relu5_1', 'relu4_1', 'relu3_1', 'relu2_1', 'relu1_1'
                 ],
                 vgg_path=None,
                 use_wavelet_pooling=False,
                 *args,
                 **kwargs):
        '''
            Args:
                mode: 'train' or 'test'. If 'train' then training & summary ops will be added to the graph
                relu_targets: List of relu target layers corresponding to decoder checkpoints
                vgg_path: Normalised VGG19 .t7 path
        '''
        self.mode = mode
        self.vgg_path = vgg_path
        self.relu_targets = relu_targets

        #### input channels went 3 -> 4 -> 5; earlier variants kept commented below
        #self.style_input = tf.placeholder_with_default(tf.constant([[[[0.,0.,0.]]]]), shape=(None, None, None, 3), name='style_img')
        #self.style_input = tf.placeholder_with_default(tf.constant([[[[0.,0.,0.,0.]]]]), shape=(None, None, None, 4), name='style_img')
        self.style_input = tf.placeholder_with_default(
            tf.constant([[[[0., 0., 0., 0., 0.]]]]),
            shape=(None, None, None, 5),
            name='style_img')

        self.use_wavelet_pooling = use_wavelet_pooling

        self.alpha = tf.placeholder_with_default(1., shape=[], name='alpha')

        # Style swap settings
        self.swap5 = tf.placeholder_with_default(tf.constant(False), shape=[])
        self.ss_alpha = tf.placeholder_with_default(.7,
                                                    shape=[],
                                                    name='ss_alpha')

        # Flag to use AdaIN instead of WCT
        self.use_adain = tf.placeholder_with_default(tf.constant(False),
                                                     shape=[])

        self.encoder_decoders = []

        if self.mode == "train":
            style_encodings = [None]
        else:
            #with tf.name_scope("style_encoder"):
            with tf.variable_scope("vgg_encoder", reuse=tf.AUTO_REUSE):
                deepest_target = sorted(relu_targets)[-1]
                self.deepest_target = deepest_target
                style_input = self.style_input[..., :-1]
                style_mask = self.style_input[..., -1]
                self.style_mask = style_mask

                vgg_outputs, indices_list, relu_targets_dict = vgg_from_t7(
                    vgg_path,
                    target_layer=deepest_target,
                    inp=style_input,
                    use_wavelet_pooling=self.use_wavelet_pooling)
                style_encoding_layers = [
                    relu_targets_dict[relu] for relu in relu_targets
                ]
                style_encodings = style_encoding_layers

        for i, (relu,
                style_encoded) in enumerate(zip(relu_targets,
                                                style_encodings)):
            print('Building encoder/decoder for relu target', relu)

            if i == 0:
                # Input tensor will be a placeholder for the first encoder/decoder
                input_tensor = None
            else:
                # Input to intermediate levels is the output from previous decoder
                input_tensor = clip(self.encoder_decoders[-1].decoded)

            enc_dec = self.build_model(relu,
                                       input_tensor=input_tensor,
                                       style_encoded_tensor=style_encoded,
                                       encoder_indices=None,
                                       **kwargs)

            self.encoder_decoders.append(enc_dec)
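This variant feeds 5-channel inputs: the slicing above treats the last channel as a mask, the build_model in the next example treats channel index 3 as a segmentation map, and everything except the mask channel goes to the VGG encoder. A hedged sketch of how such an input might be assembled (channel semantics are inferred from the slicing, not stated in the snippet):

import numpy as np

rgb  = np.random.rand(1, 256, 256, 3).astype(np.float32)             # image in [0, 1]
seg  = np.random.randint(0, 5, (1, 256, 256, 1)).astype(np.float32)  # per-pixel segment ids (assumed)
mask = np.ones((1, 256, 256, 1), dtype=np.float32)                   # mask channel (assumed)

style_batch = np.concatenate([rgb, seg, mask], axis=-1)              # shape (1, 256, 256, 5)
# style_batch[..., :-1] is what the VGG encoder consumes; style_batch[..., -1] is the mask.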
Example #6
    def build_model(self,
                    relu_target,
                    input_tensor,
                    style_encoded_tensor=None,
                    batch_size=8,
                    feature_weight=1,
                    pixel_weight=1,
                    tv_weight=1.0,
                    learning_rate=1e-4,
                    lr_decay=5e-5,
                    ss_patch_size=3,
                    ss_stride=1,
                    encoder_indices=None):
        '''Build the EncoderDecoder architecture for a given relu layer.

            Args:
                relu_target: Layer of VGG to decode from
                input_tensor: If None then a placeholder will be created, else use this tensor as the input to the encoder
                style_encoded_tensor: Tensor for style image features at the same relu layer. Used only at test time.
                batch_size: Batch size for training
                feature_weight: Float weight for feature reconstruction loss
                pixel_weight: Float weight for pixel reconstruction loss
                tv_weight: Float weight for total variation loss
                learning_rate: Float LR
                lr_decay: Float linear decay rate for the learning rate
                ss_patch_size: Patch size for the style-swap op
                ss_stride: Stride for the style-swap op
            Returns:
                EncoderDecoder namedtuple with input/encoding/output tensors and ops for training.
        '''
        with tf.name_scope('encoder_decoder_' + relu_target):
            ### Build encoder for reluX_1
            #with tf.name_scope('content_encoder_'+relu_target):
            with tf.variable_scope("vgg_encoder", reuse=tf.AUTO_REUSE):
                if input_tensor is None:
                    # This is the first level encoder that takes original content imgs
                    #### input channels went 3 -> 4 -> 5; earlier variants kept commented below
                    #content_imgs = tf.placeholder_with_default(tf.constant([[[[0.,0.,0.]]]]), shape=(None, None, None, 3), name='content_imgs')
                    #content_imgs = tf.placeholder_with_default(tf.constant([[[[0.,0.,0.,0.]]]]), shape=(None, None, None, 4), name='content_imgs')
                    content_imgs = tf.placeholder_with_default(
                        tf.constant([[[[0., 0., 0., 0., 0.]]]]),
                        shape=(None, None, None, 5),
                        name='content_imgs')
                else:
                    # This is an intermediate-level encoder that takes output tensor from previous level as input
                    content_imgs = input_tensor

                deepest_target = sorted(self.relu_targets)[-1]
                self.deepest_target = deepest_target
                vgg_inputs = content_imgs[..., :-1]
                content_mask = content_imgs[..., -1]

                #vgg_inputs = tf.zeros(shape=[8, 512, 512, 3], dtype=tf.float32)
                vgg_outputs, indices_list, relu_targets_dict = vgg_from_t7(
                    self.vgg_path,
                    target_layer=deepest_target,
                    inp=vgg_inputs,
                    use_wavelet_pooling=self.use_wavelet_pooling)

            content_layer = relu_targets_dict[relu_target]
            content_encoded = content_layer
            encoder_indices = indices_list

            ### Build style encoder & WCT if test mode
            if self.mode != 'train':
                with tf.name_scope('wct_' + relu_target):
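                    # NOTE: with this assert in place, the relu5_1 branch below is unreachable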
                    assert relu_target != "relu5_1"
                    if relu_target == 'relu5_1':
                        # Apply style swap on relu5_1 encodings if self.swap5 flag is set
                        # Use AdaIN as transfer op instead of WCT if self.use_adain is set
                        # Otherwise perform WCT
                        decoder_input = tf.case(
                            [(self.swap5, lambda: wct_style_swap(
                                content_encoded, style_encoded_tensor, self.
                                ss_alpha, ss_patch_size, ss_stride)),
                             (self.use_adain,
                              lambda: adain(content_encoded,
                                            style_encoded_tensor, self.alpha))
                             ],
                            default=lambda: wct_tf(content_encoded,
                                                   style_encoded_tensor, self.
                                                   alpha))
                    else:
                        content_encoded_list, content_one_hot_list = feature_map_lookedup(
                            content_encoded, mask=content_mask)

                        tf.identity(tf.stack(content_encoded_list, axis=0),
                                    name="content_encoded")
                        tf.identity(tf.stack(content_one_hot_list, axis=0),
                                    name="content_one_hot")

                        style_encoded_list, style_one_hot_list = feature_map_lookedup(
                            style_encoded_tensor, mask=self.style_mask)

                        tf.identity(tf.stack(style_encoded_list, axis=0),
                                    name="style_encoded")
                        tf.identity(tf.stack(style_one_hot_list, axis=0),
                                    name="style_one_hot")

                        decoder_input_list = []
                        for i in range(len(content_encoded_list)):
                            single_content_encoded = content_encoded_list[i]
                            single_style_encoded_tensor = style_encoded_list[i]
                            single_decoder_input = tf.cond(
                                self.use_adain, lambda: adain(
                                    single_content_encoded,
                                    single_style_encoded_tensor, self.alpha),
                                lambda: wct_tf(single_content_encoded,
                                               single_style_encoded_tensor,
                                               self.alpha))
                            #### [batch, to_h, to_w, cdim]
                            single_content_mask = tf.tile(
                                tf.image.resize_images(
                                    content_one_hot_list[i][..., -2:-1],
                                    (tf.shape(single_decoder_input)[1],
                                     tf.shape(single_decoder_input)[2]),
                                    method=1),
                                [1, 1, 1,
                                 tf.shape(single_decoder_input)[3]])

                            decoder_input_list.append(
                                single_decoder_input *
                                tf.cast(single_content_mask, tf.float32))

                        decoder_input = reduce(lambda a, b: a + b,
                                               decoder_input_list)

            else:  # In train mode we're trying to reconstruct from the encoding, so pass along unchanged
                decoder_input = content_encoded

            ### Build decoder
            with tf.name_scope('decoder_' + relu_target):
                n_channels = content_encoded.get_shape()[-1].value
                Bc, Hc, Wc, Cc = tf.unstack(tf.shape(decoder_input))
                decoder_input = tf.reshape(decoder_input,
                                           [Bc, Hc, Wc, n_channels])
                decoder_input_wrapped, decoded = self.build_decoder(
                    decoder_input,
                    input_shape=(None, None, n_channels),
                    relu_target=relu_target,
                    encoder_indices=encoder_indices,
                    use_wavelet_pooling=self.use_wavelet_pooling)

            # Content layer encoding of the stylized output
            with tf.variable_scope("vgg_encoder", reuse=tf.AUTO_REUSE):
                #### re-append the seg channel to the decoded output before re-encoding
                seg_input = content_imgs[..., -2:-1]
                decoded_input = tf.concat([decoded, seg_input], axis=-1)
                decoded_encoded, _, _ = vgg_from_t7(
                    self.vgg_path,
                    target_layer=self.deepest_target,
                    inp=decoded_input,
                    use_wavelet_pooling=self.use_wavelet_pooling)

        if self.mode == 'train':  # Train & summary ops only needed for training phase
            ### Losses
            with tf.name_scope('losses_' + relu_target):
                # Feature loss between encodings of original & reconstructed

                feature_loss = feature_weight * mse(decoded_encoded,
                                                    content_encoded)

                content_imgs_sliced = content_imgs
                if int(content_imgs.get_shape()[-1]) != 3:
                    content_imgs_sliced = content_imgs[..., :3]
                # Pixel reconstruction loss between decoded/reconstructed img and original
                pixel_loss = pixel_weight * mse(decoded, content_imgs_sliced)

                # Total Variation loss
                if tv_weight > 0:
                    tv_loss = tv_weight * tf.reduce_mean(
                        tf.image.total_variation(decoded))
                else:
                    tv_loss = tf.constant(0.)

                total_loss = 1.0 * feature_loss + 1.0 * pixel_loss + tv_loss

            with tf.name_scope('train_' + relu_target):
                global_step = tf.Variable(0,
                                          name='global_step_train',
                                          trainable=False)
                # self.learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 100, 0.96, staircase=False)
                learning_rate = torch_decay(learning_rate, global_step,
                                            lr_decay)
                d_optimizer = tf.train.AdamOptimizer(learning_rate,
                                                     beta1=0.9,
                                                     beta2=0.999)
                d_vars = [
                    var for var in tf.trainable_variables()
                    if 'vgg_encoder' not in var.name
                ]

                train_op = d_optimizer.minimize(total_loss,
                                                var_list=d_vars,
                                                global_step=global_step)

            ### Loss & image summaries
            with tf.name_scope('summary_' + relu_target):
                feature_loss_summary = tf.summary.scalar(
                    'feature_loss', feature_loss)
                pixel_loss_summary = tf.summary.scalar('pixel_loss',
                                                       pixel_loss)
                tv_loss_summary = tf.summary.scalar('tv_loss', tv_loss)
                total_loss_summary = tf.summary.scalar('total_loss',
                                                       total_loss)

                content_imgs_summary = tf.summary.image(
                    'content_imgs', content_imgs_sliced)
                decoded_images_summary = tf.summary.image(
                    'decoded_images', clip(decoded))

                for var in d_vars:
                    tf.summary.histogram(var.op.name, var)

                summary_op = tf.summary.merge_all()
        else:
            # For inference, set unneeded ops to None
            pixel_loss, feature_loss, tv_loss, total_loss, train_op, global_step, learning_rate, summary_op = [
                None
            ] * 8

        # Put it all together
        encoder_decoder = EncoderDecoder(content_input=content_imgs,
                                         content_encoded=content_encoded,
                                         style_encoded=style_encoded_tensor,
                                         decoder_input=decoder_input,
                                         decoded=decoded,
                                         decoded_encoded=decoded_encoded,
                                         pixel_loss=pixel_loss,
                                         feature_loss=feature_loss,
                                         tv_loss=tv_loss,
                                         total_loss=total_loss,
                                         train_op=train_op,
                                         global_step=global_step,
                                         learning_rate=learning_rate,
                                         summary_op=summary_op)

        return encoder_decoder
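Several helpers referenced above (clip, mse, torch_decay) are not included in the snippet. The sketches below are plausible TF1-style implementations consistent with how they are called; treat the bodies as assumptions rather than the repo's actual code:

import tensorflow as tf

def clip(x):
    # Keep decoded images in a valid [0, 1] range before re-encoding or summarizing them.
    return tf.clip_by_value(x, 0., 1.)

def mse(a, b):
    # Mean squared error used for the feature and pixel reconstruction losses.
    return tf.reduce_mean(tf.square(a - b))

def torch_decay(learning_rate, global_step, decay_rate):
    # Linear-in-step decay, lr_t = lr0 / (1 + t * decay), matching Torch's optim convention.
    return learning_rate / (1. + tf.cast(global_step, tf.float32) * decay_rate)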