Esempio n. 1
0
    def build_model(self,
                    relu_target,
                    input_tensor,
                    style_encoded_tensor=None,
                    batch_size=8,
                    feature_weight=1,
                    pixel_weight=1,
                    tv_weight=0,
                    learning_rate=1e-4,
                    lr_decay=5e-5,
                    ss_patch_size=3,
                    ss_stride=1):
        '''Assemble the encoder -> (feature transform) -> decoder graph for one VGG relu layer.

            Args:
                relu_target: Name of the VGG layer whose activations are encoded and decoded.
                input_tensor: Encoder input. When None, a 3-channel 'content_imgs'
                    placeholder (defaulting to a single zero pixel) is created instead.
                style_encoded_tensor: Style features at the same relu layer; consumed only
                    outside 'train' mode by the transfer ops.
                batch_size: Not used in this method body.
                feature_weight: Float multiplier on the feature reconstruction MSE.
                pixel_weight: Float multiplier on the pixel reconstruction MSE.
                tv_weight: Float multiplier on total variation loss; <= 0 disables it.
                learning_rate: Float initial LR handed to torch_decay.
                lr_decay: Float decay rate handed to torch_decay.
                ss_patch_size: Patch size for wct_style_swap (relu5_1, non-train mode only).
                ss_stride: Stride for wct_style_swap (relu5_1, non-train mode only).
            Returns:
                EncoderDecoder namedtuple holding the input/encoding/output tensors and, in
                'train' mode, the loss, optimizer, step, LR and summary ops (None otherwise).
        '''
        tag = relu_target

        with tf.name_scope('encoder_decoder_' + tag):

            # --- Content encoder: the VGG model truncated at `tag` ---
            with tf.name_scope('content_encoder_' + tag):
                # The first level gets a fresh placeholder; deeper levels chain off the previous output.
                content_imgs = input_tensor if input_tensor is not None else tf.placeholder_with_default(
                    tf.constant([[[[0., 0., 0.]]]]),
                    shape=(None, None, None, 3),
                    name='content_imgs')

                target_output = self.vgg_model.get_layer(tag).output
                content_encoder_model = Model(inputs=self.vgg_model.input, outputs=target_output)
                content_encoded = content_encoder_model(content_imgs)

            # --- Feature transform; in training the decoder reconstructs the raw encoding ---
            if self.mode == 'train':
                decoder_input = content_encoded
            else:
                with tf.name_scope('wct_' + tag):
                    def run_adain():
                        return adain(content_encoded, style_encoded_tensor, self.alpha)

                    def run_wct():
                        return wct_tf(content_encoded, style_encoded_tensor, self.alpha)

                    if tag == 'relu5_1':
                        # First true predicate wins: style swap (self.swap5), then AdaIN
                        # (self.use_adain); WCT is the fallback.
                        def run_swap():
                            return wct_style_swap(content_encoded,
                                                  style_encoded_tensor,
                                                  self.ss_alpha,
                                                  ss_patch_size,
                                                  ss_stride)

                        decoder_input = tf.case([(self.swap5, run_swap),
                                                 (self.use_adain, run_adain)],
                                                default=run_wct)
                    else:
                        decoder_input = tf.cond(self.use_adain, run_adain, run_wct)

            # --- Decoder ---
            with tf.name_scope('decoder_' + tag):
                n_channels = content_encoded.get_shape()[-1].value
                decoder_model = self.build_decoder(input_shape=(None, None, n_channels), relu_target=tag)

                # Pin the static channel count so the tensor matches decoder_model's input shape.
                decoder_input_wrapped = tf.placeholder_with_default(
                    decoder_input, shape=[None, None, None, n_channels])
                # Identity Lambda layer turns the raw TF tensor into a Keras tensor.
                keras_decoder_input = Lambda(lambda x: x)(decoder_input_wrapped)
                decoded = decoder_model(keras_decoder_input)

            # Re-encode the decoded output with the same encoder (feeds the feature loss).
            decoded_encoded = content_encoder_model(decoded)

        if self.mode != 'train':
            # Inference: no losses, optimizer or summaries are built.
            pixel_loss = feature_loss = tv_loss = total_loss = None
            train_op = global_step = decayed_lr = summary_op = None
        else:
            with tf.name_scope('losses_' + tag):
                feature_loss = feature_weight * mse(decoded_encoded, content_encoded)
                pixel_loss = pixel_weight * mse(decoded, content_imgs)
                tv_loss = (tv_weight * tf.reduce_mean(tf.image.total_variation(decoded))
                           if tv_weight > 0 else tf.constant(0.))
                total_loss = feature_loss + pixel_loss + tv_loss

            with tf.name_scope('train_' + tag):
                global_step = tf.Variable(0, name='global_step_train', trainable=False)
                decayed_lr = torch_decay(learning_rate, global_step, lr_decay)
                optimizer = tf.train.AdamOptimizer(decayed_lr, beta1=0.9, beta2=0.999)

                # Encoder stays frozen: only variables whose name contains this level's
                # decoder tag are optimized.
                decoder_tag = 'decoder_' + tag
                d_vars = [v for v in tf.trainable_variables() if decoder_tag in v.name]
                train_op = optimizer.minimize(total_loss, var_list=d_vars, global_step=global_step)

            with tf.name_scope('summary_' + tag):
                scalars = (('feature_loss', feature_loss),
                           ('pixel_loss', pixel_loss),
                           ('tv_loss', tv_loss),
                           ('total_loss', total_loss))
                for summary_name, value in scalars:
                    tf.summary.scalar(summary_name, value)

                tf.summary.image('content_imgs', content_imgs)
                tf.summary.image('decoded_images', clip(decoded))

                for v in d_vars:
                    tf.summary.histogram(v.op.name, v)

                # merge_all collects every summary in the default graph, not only this level's.
                summary_op = tf.summary.merge_all()

        return EncoderDecoder(content_input=content_imgs,
                              content_encoder_model=content_encoder_model,
                              content_encoded=content_encoded,
                              style_encoded=style_encoded_tensor,
                              decoder_input=decoder_input,
                              decoder_model=decoder_model,
                              decoded=decoded,
                              decoded_encoded=decoded_encoded,
                              pixel_loss=pixel_loss,
                              feature_loss=feature_loss,
                              tv_loss=tv_loss,
                              total_loss=total_loss,
                              train_op=train_op,
                              global_step=global_step,
                              learning_rate=decayed_lr,
                              summary_op=summary_op)
Esempio n. 2
0
    def build_model(self,
                    relu_target,
                    input_tensor,
                    style_encoded_tensor=None,
                    batch_size=8,
                    feature_weight=1,
                    pixel_weight=1,
                    tv_weight=1.0,
                    learning_rate=1e-4,
                    lr_decay=5e-5,
                    ss_patch_size=3,
                    ss_stride=1,
                    encoder_indices=None):
        '''Build the mask-aware EncoderDecoder architecture for a given relu layer.

            The content input carries extra channels beyond the image:
              * channels [..., :-1] are fed to the VGG encoder,
              * the last channel [..., -1] is the mask passed to feature_map_lookedup,
              * channel [..., -2:-1] is re-appended to the decoded output before it is
                re-encoded,
              * the pixel loss compares only channels [..., :3] with the decoded output.

            Args:
                relu_target: Layer of VGG to decode from. Outside 'train' mode 'relu5_1'
                    is rejected because the mask-aware transfer path does not handle it.
                input_tensor: If None a 5-channel 'content_imgs' placeholder is created,
                    else this tensor is used as the encoder input.
                style_encoded_tensor: Style features at the same relu layer. Used only
                    outside 'train' mode, split per region using self.style_mask.
                batch_size: Unused in this method.
                feature_weight: Float weight for feature reconstruction loss.
                pixel_weight: Float weight for pixel reconstruction loss.
                tv_weight: Float weight for total variation loss; <= 0 disables it.
                learning_rate: Float initial LR handed to torch_decay.
                lr_decay: Float decay rate handed to torch_decay.
                ss_patch_size: Unused (only the unsupported relu5_1 style-swap path took
                    it); kept for interface compatibility.
                ss_stride: Unused; kept for interface compatibility.
                encoder_indices: Ignored; replaced by the pooling indices returned by
                    vgg_from_t7 before being passed to build_decoder.
            Returns:
                EncoderDecoder namedtuple with input/encoding/output tensors and, in
                'train' mode, the loss/optimizer/summary ops (None otherwise).
            Raises:
                ValueError: if relu_target is 'relu5_1' outside 'train' mode.
        '''
        # `reduce` is not a builtin on Python 3; import it explicitly.
        import functools
        import operator

        with tf.name_scope('encoder_decoder_' + relu_target):
            ### Build the shared (reusable) VGG encoder up to the deepest target layer
            with tf.variable_scope("vgg_encoder", reuse=tf.AUTO_REUSE):
                if input_tensor is None:
                    # First-level encoder: takes the raw 5-channel content input
                    content_imgs = tf.placeholder_with_default(
                        tf.constant([[[[0., 0., 0., 0., 0.]]]]),
                        shape=(None, None, None, 5),
                        name='content_imgs')
                else:
                    # Intermediate-level encoder: takes the previous level's output tensor
                    content_imgs = input_tensor

                # Lexicographic order works for single-digit names like 'relu1_1'..'relu5_1'
                deepest_target = sorted(self.relu_targets)[-1]
                self.deepest_target = deepest_target
                vgg_inputs = content_imgs[..., :-1]
                content_mask = content_imgs[..., -1]

                _, indices_list, relu_targets_dict = vgg_from_t7(
                    self.vgg_path,
                    target_layer=deepest_target,
                    inp=vgg_inputs,
                    use_wavelet_pooling=self.use_wavelet_pooling)

            content_encoded = relu_targets_dict[relu_target]
            # Any caller-supplied encoder_indices is discarded in favor of the encoder's own.
            encoder_indices = indices_list

            ### Mask-aware feature transform (test mode only)
            if self.mode != 'train':
                with tf.name_scope('wct_' + relu_target):
                    if relu_target == 'relu5_1':
                        # Explicit check instead of `assert`, which vanishes under `python -O`.
                        raise ValueError(
                            "relu5_1 is not supported by the mask-aware transfer path "
                            "outside 'train' mode")

                    content_encoded_list, content_one_hot_list = feature_map_lookedup(
                        content_encoded, mask=content_mask)
                    # Named identities so these intermediates can be fetched from the graph by name.
                    tf.identity(tf.stack(content_encoded_list, axis=0),
                                name="content_encoded")
                    tf.identity(tf.stack(content_one_hot_list, axis=0),
                                name="content_one_hot")

                    style_encoded_list, style_one_hot_list = feature_map_lookedup(
                        style_encoded_tensor, mask=self.style_mask)
                    tf.identity(tf.stack(style_encoded_list, axis=0),
                                name="style_encoded")
                    tf.identity(tf.stack(style_one_hot_list, axis=0),
                                name="style_one_hot")

                    decoder_input_list = []
                    for i, single_content_encoded in enumerate(content_encoded_list):
                        single_style_encoded_tensor = style_encoded_list[i]
                        # tf.cond invokes both lambdas immediately while building the graph,
                        # so the loop variables are captured with their current values.
                        single_decoder_input = tf.cond(
                            self.use_adain,
                            lambda: adain(single_content_encoded,
                                          single_style_encoded_tensor, self.alpha),
                            lambda: wct_tf(single_content_encoded,
                                           single_style_encoded_tensor, self.alpha))

                        # Resize this region's one-hot slice to the feature map's spatial size
                        # (method=1 is ResizeMethod.NEAREST_NEIGHBOR in TF1) and tile it across
                        # channels: [batch, to_h, to_w, cdim].
                        out_shape = tf.shape(single_decoder_input)
                        single_content_mask = tf.tile(
                            tf.image.resize_images(
                                content_one_hot_list[i][..., -2:-1],
                                (out_shape[1], out_shape[2]),
                                method=1),
                            [1, 1, 1, out_shape[3]])

                        decoder_input_list.append(
                            single_decoder_input *
                            tf.cast(single_content_mask, tf.float32))

                    # Sum the per-region masked results into a single feature map.
                    decoder_input = functools.reduce(operator.add, decoder_input_list)

            else:  # In train mode we reconstruct from the encoding, so pass it along unchanged
                decoder_input = content_encoded

            ### Build decoder
            with tf.name_scope('decoder_' + relu_target):
                n_channels = content_encoded.get_shape()[-1].value
                # Restore the static channel count, which the masked transfer ops presumably
                # leave unknown, so build_decoder sees a fully specified last dimension.
                batch, height, width, _ = tf.unstack(tf.shape(decoder_input))
                decoder_input = tf.reshape(decoder_input,
                                           [batch, height, width, n_channels])
                _, decoded = self.build_decoder(
                    decoder_input,
                    input_shape=(None, None, n_channels),
                    relu_target=relu_target,
                    encoder_indices=encoder_indices,
                    use_wavelet_pooling=self.use_wavelet_pooling)

            ### Re-encode the decoded output (feeds the feature loss)
            with tf.variable_scope("vgg_encoder", reuse=tf.AUTO_REUSE):
                # The encoder input carries channel [..., -2:-1] next to the image, so
                # re-append it from the content input before re-encoding.
                seg_input = content_imgs[..., -2:-1]
                decoded_input = tf.concat([decoded, seg_input], axis=-1)
                _, _, decoded_targets_dict = vgg_from_t7(
                    self.vgg_path,
                    target_layer=self.deepest_target,
                    inp=decoded_input,
                    use_wavelet_pooling=self.use_wavelet_pooling)
                # Take the same layer as content_encoded; the deepest-layer output differs in
                # shape from content_encoded whenever relu_target is not the deepest target.
                decoded_encoded = decoded_targets_dict[relu_target]

        if self.mode == 'train':  # Train & summary ops only needed for training phase
            ### Losses
            with tf.name_scope('losses_' + relu_target):
                # Feature loss between encodings of original & reconstructed
                feature_loss = feature_weight * mse(decoded_encoded, content_encoded)

                # Compare only the first three channels with the decoder output. An unknown
                # channel count (value None) is sliced as well instead of crashing in int().
                if content_imgs.get_shape()[-1].value == 3:
                    content_imgs_sliced = content_imgs
                else:
                    content_imgs_sliced = content_imgs[..., :3]
                # Pixel reconstruction loss between decoded/reconstructed img and original
                pixel_loss = pixel_weight * mse(decoded, content_imgs_sliced)

                # Total Variation loss
                if tv_weight > 0:
                    tv_loss = tv_weight * tf.reduce_mean(
                        tf.image.total_variation(decoded))
                else:
                    tv_loss = tf.constant(0.)

                total_loss = feature_loss + pixel_loss + tv_loss

            with tf.name_scope('train_' + relu_target):
                global_step = tf.Variable(0,
                                          name='global_step_train',
                                          trainable=False)
                learning_rate = torch_decay(learning_rate, global_step, lr_decay)
                d_optimizer = tf.train.AdamOptimizer(learning_rate,
                                                     beta1=0.9,
                                                     beta2=0.999)
                # Train every trainable variable not created under the frozen 'vgg_encoder' scope.
                d_vars = [
                    var for var in tf.trainable_variables()
                    if 'vgg_encoder' not in var.name
                ]

                train_op = d_optimizer.minimize(total_loss,
                                                var_list=d_vars,
                                                global_step=global_step)

            ### Loss & image summaries
            with tf.name_scope('summary_' + relu_target):
                tf.summary.scalar('feature_loss', feature_loss)
                tf.summary.scalar('pixel_loss', pixel_loss)
                tf.summary.scalar('tv_loss', tv_loss)
                tf.summary.scalar('total_loss', total_loss)

                tf.summary.image('content_imgs', content_imgs_sliced)
                tf.summary.image('decoded_images', clip(decoded))

                for var in d_vars:
                    tf.summary.histogram(var.op.name, var)

                # merge_all collects every summary in the default graph.
                summary_op = tf.summary.merge_all()
        else:
            # For inference, set unneeded ops to None
            pixel_loss = feature_loss = tv_loss = total_loss = None
            train_op = global_step = learning_rate = summary_op = None

        # Put it all together
        encoder_decoder = EncoderDecoder(content_input=content_imgs,
                                         content_encoded=content_encoded,
                                         style_encoded=style_encoded_tensor,
                                         decoder_input=decoder_input,
                                         decoded=decoded,
                                         decoded_encoded=decoded_encoded,
                                         pixel_loss=pixel_loss,
                                         feature_loss=feature_loss,
                                         tv_loss=tv_loss,
                                         total_loss=total_loss,
                                         train_op=train_op,
                                         global_step=global_step,
                                         learning_rate=learning_rate,
                                         summary_op=summary_op)

        return encoder_decoder
Esempio n. 3
0
    def build_model(self,
                    relu_target,
                    input_tensor,
                    style_encoded_tensor=None,
                    batch_size=8,
                    feature_weight=1,
                    pixel_weight=1,
                    tv_weight=0,
                    learning_rate=1e-4,
                    lr_decay=5e-5):
        '''Assemble the encoder -> optional WCT -> decoder graph for one VGG relu layer.

            Args:
                relu_target: Name of the VGG layer whose activations are encoded and decoded.
                input_tensor: Encoder input. When None, a 3-channel 'content_imgs'
                    placeholder (defaulting to a single zero pixel) is created instead.
                style_encoded_tensor: Style features at the same relu layer; consumed only
                    outside 'train' mode by wct_tf.
                batch_size: Not used in this method body.
                feature_weight: Float multiplier on the feature reconstruction MSE.
                pixel_weight: Float multiplier on the pixel reconstruction MSE.
                tv_weight: Float multiplier on total variation loss; <= 0 disables it.
                learning_rate: Float initial LR handed to torch_decay.
                lr_decay: Float decay rate handed to torch_decay.
            Returns:
                EncoderDecoder namedtuple holding the input/encoding/output tensors and, in
                'train' mode, the loss, optimizer, step, LR and summary ops (None otherwise).
        '''
        tag = relu_target

        with tf.name_scope('encoder_decoder_' + tag):

            # --- Content encoder: the VGG model truncated at `tag` ---
            with tf.name_scope('content_encoder_' + tag):
                # The first level gets a fresh placeholder; deeper levels chain off the previous output.
                content_imgs = input_tensor if input_tensor is not None else tf.placeholder_with_default(
                    tf.constant([[[[0., 0., 0.]]]]),
                    shape=(None, None, None, 3),
                    name='content_imgs')

                target_output = self.vgg_model.get_layer(tag).output
                content_encoder_model = Model(inputs=self.vgg_model.input, outputs=target_output)
                content_encoded = content_encoder_model(content_imgs)

            # --- Optional WCT; in training the decoder reconstructs the raw encoding ---
            if self.mode == 'train':
                decoder_input = content_encoded
            else:
                with tf.name_scope('wct_' + tag):
                    def stylize():
                        return wct_tf(content_encoded, style_encoded_tensor, self.alpha)

                    def passthrough():
                        return content_encoded

                    # Stylize only when self.apply_wct is true; otherwise forward the encoding.
                    decoder_input = tf.cond(self.apply_wct, stylize, passthrough)

            # --- Decoder ---
            with tf.name_scope('decoder_' + tag):
                n_channels = content_encoded.get_shape()[-1].value
                decoder_model = self.build_decoder(input_shape=(None, None, n_channels), relu_target=tag)

                # Pin the static channel count so the tensor matches decoder_model's input shape.
                decoder_input_wrapped = tf.placeholder_with_default(
                    decoder_input, shape=[None, None, None, n_channels])
                # Identity Lambda layer turns the raw TF tensor into a Keras tensor.
                keras_decoder_input = Lambda(lambda x: x)(decoder_input_wrapped)
                decoded = decoder_model(keras_decoder_input)

            # Re-encode the decoded output with the same encoder (feeds the feature loss).
            decoded_encoded = content_encoder_model(decoded)

        if self.mode != 'train':
            # Inference: no losses, optimizer or summaries are built.
            pixel_loss = feature_loss = tv_loss = total_loss = None
            train_op = global_step = decayed_lr = summary_op = None
        else:
            with tf.name_scope('losses_' + tag):
                feature_loss = feature_weight * mse(decoded_encoded, content_encoded)
                pixel_loss = pixel_weight * mse(decoded, content_imgs)
                tv_loss = (tv_weight * tf.reduce_mean(tf.image.total_variation(decoded))
                           if tv_weight > 0 else tf.constant(0.))
                total_loss = feature_loss + pixel_loss + tv_loss

            with tf.name_scope('train_' + tag):
                global_step = tf.Variable(0, name='global_step_train', trainable=False)
                decayed_lr = torch_decay(learning_rate, global_step, lr_decay)
                optimizer = tf.train.AdamOptimizer(decayed_lr, beta1=0.9, beta2=0.999)

                # Encoder stays frozen: only variables whose name contains this level's
                # decoder tag are optimized.
                decoder_tag = 'decoder_' + tag
                d_vars = [v for v in tf.trainable_variables() if decoder_tag in v.name]
                train_op = optimizer.minimize(total_loss, var_list=d_vars, global_step=global_step)

            with tf.name_scope('summary_' + tag):
                scalars = (('feature_loss', feature_loss),
                           ('pixel_loss', pixel_loss),
                           ('tv_loss', tv_loss),
                           ('total_loss', total_loss))
                for summary_name, value in scalars:
                    tf.summary.scalar(summary_name, value)

                tf.summary.image('content_imgs', content_imgs)
                tf.summary.image('decoded_images', clip(decoded))

                for v in d_vars:
                    tf.summary.histogram(v.op.name, v)

                # merge_all collects every summary in the default graph, not only this level's.
                summary_op = tf.summary.merge_all()

        return EncoderDecoder(content_input=content_imgs,
                              content_encoder_model=content_encoder_model,
                              content_encoded=content_encoded,
                              style_encoded=style_encoded_tensor,
                              decoder_input=decoder_input,
                              decoder_model=decoder_model,
                              decoded=decoded,
                              decoded_encoded=decoded_encoded,
                              pixel_loss=pixel_loss,
                              feature_loss=feature_loss,
                              tv_loss=tv_loss,
                              total_loss=total_loss,
                              train_op=train_op,
                              global_step=global_step,
                              learning_rate=decayed_lr,
                              summary_op=summary_op)
Esempio n. 4
0
    def build_train(self,
                    batch_size=8,
                    content_weight=1,
                    style_weight=1e-2,
                    tv_weight=0,
                    learning_rate=1e-4,
                    lr_decay=5e-5,
                    use_gram=False):
        '''Build the style-transfer training losses and the decoder optimizer op.

            Reads self.vgg_model, self.style_imgs, self.decoded, self.decoded_encoded and
            self.adain_encoded; stores the losses and training ops on self.

            Args:
                batch_size: Batch size; the style losses are divided by it.
                content_weight: Float weight for the content loss between the re-encoded
                    decoder output and self.adain_encoded.
                style_weight: Float weight for the style loss, applied in both the
                    mean/std and the Gram-matrix formulations.
                tv_weight: Float weight for total variation loss; <= 0 disables it.
                learning_rate: Float initial LR handed to torch_decay.
                lr_decay: Float decay rate handed to torch_decay.
                use_gram: If True use a Gram-matrix style loss, else a mean/std loss.
        '''
        ### Extract style layer feature maps for input style & decoded stylized output
        with tf.name_scope('style_layers'):
            # Style model outputs the relu1_1 .. relu4_1 activations of the VGG model
            relu_layers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1']

            style_layers = [
                self.vgg_model.get_layer(layer_name).output for layer_name in relu_layers
            ]
            self.style_layer_model = Model(inputs=self.vgg_model.input,
                                           outputs=style_layers)

            self.style_fmaps = self.style_layer_model(self.style_imgs)
            self.decoded_fmaps = self.style_layer_model(self.decoded)

        ### Losses
        with tf.name_scope('losses'):
            # Content loss between stylized encoding and AdaIN encoding
            self.content_loss = content_weight * mse(self.decoded_encoded,
                                                     self.adain_encoded)

            # Style losses
            if not use_gram:  # Match per-channel mean/std over axes 1 and 2 (presumably H, W)
                mean_std_losses = []
                for s_map, d_map in zip(self.style_fmaps, self.decoded_fmaps):
                    s_mean, s_var = tf.nn.moments(s_map, [1, 2])
                    d_mean, d_var = tf.nn.moments(d_map, [1, 2])
                    # normalized w.r.t. batch size
                    m_loss = sse(d_mean, s_mean) / batch_size
                    # NOTE(review): tf.sqrt has an infinite gradient at 0, so a decoded channel
                    # with zero variance may produce NaN gradients; consider adding an epsilon.
                    s_loss = sse(tf.sqrt(d_var), tf.sqrt(s_var)) / batch_size

                    mean_std_losses.append(style_weight * (m_loss + s_loss))

                self.style_loss = tf.reduce_sum(mean_std_losses)
            else:  # Use gram matrices for style loss instead
                gram_losses = []
                for s_map, d_map in zip(self.style_fmaps, self.decoded_fmaps):
                    s_gram = gram_matrix(s_map)
                    d_gram = gram_matrix(d_map)
                    gram_losses.append(mse(d_gram, s_gram))
                # Weight by style_weight as the mean/std branch does; without it the
                # style_weight argument would have no effect when use_gram=True.
                self.style_loss = style_weight * tf.reduce_sum(gram_losses) / batch_size

            # Total Variation loss
            if tv_weight > 0:
                self.tv_loss = tv_weight * tf.reduce_mean(
                    tf.image.total_variation(self.decoded))
            else:
                self.tv_loss = tf.constant(0.)

            # Add it all together
            self.total_loss = self.content_loss + self.style_loss + self.tv_loss

        ### Training ops
        with tf.name_scope('train'):
            self.global_step = tf.Variable(0,
                                           name='global_step_train',
                                           trainable=False)
            self.learning_rate = torch_decay(learning_rate, self.global_step,
                                             lr_decay)
            # NOTE(review): beta2=0.9 differs from Adam's usual 0.999 default — confirm intended.
            d_optimizer = tf.train.AdamOptimizer(self.learning_rate,
                                                 beta1=0.9,
                                                 beta2=0.9)

            # Only train decoder vars, encoder is frozen
            self.d_vars = [var for var in tf.trainable_variables() if 'decoder' in var.name]

            self.train_op = d_optimizer.minimize(self.total_loss,
                                                 var_list=self.d_vars,
                                                 global_step=self.global_step)