Example #1
0
    def build_decoder(self, input_shape, relu_target):
        '''Build the decoder architecture that reconstructs from a given VGG relu layer.

            The decoder for ``relu_target`` stacks the layers of its own level and of
            every shallower level (deepest first), then appends a final 3-filter
            Conv2DReflect with no activation.

            Args:
                input_shape: Tuple of input tensor shape, needed for channel dimension
                relu_target: Layer of VGG to decode from; one of 'relu1_1', 'relu2_1',
                    'relu3_1', 'relu4_1', 'relu5_1'

            Returns:
                The decoder ``Model`` named ``'decoder_model_' + relu_target``.

            Raises:
                KeyError: if ``relu_target`` is not one of the supported layer names.
        '''
        relu_targets = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']
        if relu_target not in relu_targets:
            # Same exception type as the former dict lookup, but with a useful message.
            raise KeyError('Unsupported relu_target {!r}; expected one of {}'.format(
                relu_target, relu_targets))
        decoder_num = relu_targets.index(relu_target) + 1

        # Dict specifying the layers for each decoder level. relu5_1 is the deepest
        # decoder and will contain all layers.
        decoder_archs = {
            5: [ #    layer    filts      HxW  / InC->OutC
                (Conv2DReflect, 512),  # 16x16 / 512->512
                (UpSampling2D,),       # 16x16 -> 32x32
                (Conv2DReflect, 512),  # 32x32 / 512->512
                (Conv2DReflect, 512),  # 32x32 / 512->512
                (Conv2DReflect, 512)], # 32x32 / 512->512
            4: [
                (Conv2DReflect, 256),  # 32x32 / 512->256
                (UpSampling2D,),       # 32x32 -> 64x64
                (Conv2DReflect, 256),  # 64x64 / 256->256
                (Conv2DReflect, 256),  # 64x64 / 256->256
                (Conv2DReflect, 256)], # 64x64 / 256->256
            3: [
                (Conv2DReflect, 128),  # 64x64 / 256->128
                (UpSampling2D,),       # 64x64 -> 128x128
                (Conv2DReflect, 128)], # 128x128 / 128->128
            2: [
                (Conv2DReflect, 64),   # 128x128 / 128->64
                (UpSampling2D,)],      # 128x128 -> 256x256
            1: [
                (Conv2DReflect, 64)]   # 256x256 / 64->64
        }

        code = Input(shape=input_shape, name='decoder_input_'+relu_target)
        x = code

        # Work backwards from the deepest decoder level and build layer by layer.
        count = 0
        for d in reversed(range(1, decoder_num + 1)):
            for layer_tup in decoder_archs[d]:
                # Unique layer names are needed to ensure var naming consistency
                # with multiple decoders in graph.
                layer_name = '{}_{}'.format(relu_target, count)

                if layer_tup[0] is Conv2DReflect:
                    x = Conv2DReflect(layer_name, filters=layer_tup[1], kernel_size=3,
                                      padding='valid', activation='relu',
                                      name=layer_name)(x)
                elif layer_tup[0] is UpSampling2D:
                    x = UpSampling2D(name=layer_name)(x)

                count += 1

        layer_name = '{}_{}'.format(relu_target, count)
        output = Conv2DReflect(layer_name, filters=3, kernel_size=3, padding='valid',
                               activation=None, name=layer_name)(x)  # 256x256 / 64->3

        decoder_model = Model(code, output, name='decoder_model_'+relu_target)
        # summary() prints the table itself and returns None; wrapping it in
        # print() used to emit a stray "None" line.
        decoder_model.summary()
        return decoder_model
Example #2
0
    def build_decoder(self, input_shape):
        '''Build the fixed decoder that reconstructs an image from a 512-channel code.

            The layer stack upsamples three times (the comments give the sizes for a
            32x32 input) and ends in a 3-filter Conv2DReflect with no activation.
            The layers are called inside ``tf.variable_scope('decoder')``.

            Args:
                input_shape: Shape tuple (without batch dimension) of the decoder input.

            Returns:
                The decoder ``Model`` named ``'decoder_model'``.
        '''
        arch = [  #  HxW  / InC->OutC
            Conv2DReflect(256, 3, padding='valid',
                          activation='relu'),  # 32x32 / 512->256
            UpSampling2D(),  # 32x32 -> 64x64
            Conv2DReflect(256, 3, padding='valid',
                          activation='relu'),  # 64x64 / 256->256
            Conv2DReflect(256, 3, padding='valid',
                          activation='relu'),  # 64x64 / 256->256
            Conv2DReflect(256, 3, padding='valid',
                          activation='relu'),  # 64x64 / 256->256
            Conv2DReflect(128, 3, padding='valid',
                          activation='relu'),  # 64x64 / 256->128
            UpSampling2D(),  # 64x64 -> 128x128
            Conv2DReflect(128, 3, padding='valid',
                          activation='relu'),  # 128x128 / 128->128
            Conv2DReflect(64, 3, padding='valid',
                          activation='relu'),  # 128x128 / 128->64
            UpSampling2D(),  # 128x128 -> 256x256
            Conv2DReflect(64, 3, padding='valid',
                          activation='relu'),  # 256x256 / 64->64
            Conv2DReflect(3, 3, padding='valid', activation=None),  # 256x256 / 64->3
        ]

        code = Input(shape=input_shape, name='decoder_input')
        x = code

        with tf.variable_scope('decoder'):
            for layer in arch:
                x = layer(x)

        decoder = Model(code, x, name='decoder_model')
        # summary() prints the table itself and returns None; wrapping it in
        # print() used to emit a stray "None" line.
        decoder.summary()
        return decoder
Example #3
0
    def build_decoder(self, relu_target, input_shape=None):
        '''Build the decoder architecture that reconstructs from a given VGG relu layer.

            The decoder for ``relu_target`` stacks the layers of its own level and of
            every shallower level (deepest first), then a final 3-filter
            Conv2DReflect with no activation.

            Args:
                relu_target: Layer of VGG to decode from; one of 'relu1_1', 'relu2_1',
                    'relu3_1', 'relu4_1', 'relu5_1'.
                input_shape: Optional tuple of input tensor shape (without batch
                    dimension). Defaults to
                    ``(256, 256, self.get_layer_channels_number(relu_target))``,
                    which was previously the only supported value.

            Returns:
                A ``tf.keras.Sequential`` model named ``decoder_model_<relu_target>``.

            Raises:
                KeyError: if ``relu_target`` is not one of the supported layer names.
        '''
        if input_shape is None:
            # Historical default: fixed 256x256 spatial size, channel count taken
            # from the encoder layer being decoded.
            input_shape = (256, 256, self.get_layer_channels_number(relu_target))
        decoder_num = dict(
            zip(['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'],
                range(1, 6)))[relu_target]

        # Layers for each decoder level. relu5_1 is the deepest decoder and will
        # contain all layers. Every level is instantiated (as before), even the
        # ones this decoder does not use.
        layers_by_level = {
            5: [  # layer    filts      HxW  / InC->OutC
                Conv2DRelu(512),  # 16x16 / 512->512
                UpSampling2D(),  # 16x16 -> 32x32
                Conv2DRelu(512),  # 32x32 / 512->512
                Conv2DRelu(512),  # 32x32 / 512->512
                Conv2DRelu(512),  # 32x32 / 512->512
            ],
            4: [
                Conv2DRelu(256),  # 32x32 / 512->256
                UpSampling2D(),  # 32x32 -> 64x64
                Conv2DRelu(256),  # 64x64 / 256->256
                Conv2DRelu(256),  # 64x64 / 256->256
                Conv2DRelu(256),  # 64x64 / 256->256
            ],
            3: [
                Conv2DRelu(128),  # 64x64 / 256->128
                UpSampling2D(),  # 64x64 -> 128x128
                Conv2DRelu(128),  # 128x128 / 128->128
            ],
            2: [
                Conv2DRelu(64),  # 128x128 / 128->64
                UpSampling2D(),  # 128x128 -> 256x256
            ],
            1: [Conv2DRelu(64)],  # 256x256 / 64->64
        }

        # Deepest selected level first, down to level 1, flattened into one list.
        middle_layers = list(itertools.chain.from_iterable(
            layers_by_level[i] for i in range(decoder_num, 0, -1)))

        return tf.keras.Sequential(
            [Input(shape=input_shape), *middle_layers,
             Conv2DReflect(filters=3, activation=None)],
            name=f'decoder_model_{relu_target}')
Example #4
0
    def build_decoder(self,
                      code,
                      input_shape,
                      relu_target,
                      encoder_indices=None,
                      use_wavelet_pooling=True):
        '''Build the decoder graph that reconstructs from a given VGG relu layer.

            Unlike a Keras-Model builder, this wires TF ops directly onto the given
            ``code`` tensor and returns the ``(code, output)`` tensor pair. Levels
            are applied from the deepest selected level down to level 1, and a final
            3-filter Conv2DReflect with no activation is appended.

            Args:
                code: Input tensor to decode. NOTE(review): presumably NHWC, since
                    axes 1/2 are read as h/w and axis -1 as channels below — confirm.
                input_shape: Currently unused; it only appears in the commented-out
                    Input/placeholder lines below. Kept for interface compatibility.
                relu_target: Layer of VGG to decode from; one of 'relu1_1', 'relu2_1',
                    'relu3_1', 'relu4_1', 'relu5_1'. Also the prefix of layer names.
                encoder_indices: Iterable of tensors saved by the encoder. At every
                    upsampling step exactly one of them must have the same
                    last-dimension size as the current ``x``. It is required whenever
                    ``relu_target`` is deeper than 'relu1_1': with the default
                    ``None``, ``filter(..., None)`` raises TypeError.
                use_wavelet_pooling: If truthy, upsample through the ``tf.py_func`` /
                    ``single_unpooling_func`` path followed by a 3x3 conv; otherwise
                    use ``unpooling`` on a 2x-resized ``x`` followed by a 1x1 conv.

            Returns:
                Tuple ``(code, output)``; ``output`` is the final 3-filter conv result.

            Raises:
                KeyError: if ``relu_target`` is not a supported layer name.
                TypeError: if ``encoder_indices`` is None and an upsampling step is reached.
                AssertionError: if an upsampling step does not find exactly one
                    matching encoder index (note: asserts are stripped under ``python -O``).
        '''

        decoder_num = dict(
            zip(['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'],
                range(1, 6)))[relu_target]

        # Dict specifying the layers for each decoder level. relu5_1 is the deepest decoder and will contain all layers
        # Conv entries are (Conv2DReflect, filters). UpSampling entries are
        # (UpSampling2D, input H/W, channels); those extra fields are only used to
        # name the `indice` identity op and (non-wavelet path) as `c` for `unpooling`.

        decoder_archs = {
            5: [  #    layer    filts      HxW  / InC->OutC
                (Conv2DReflect, 512),  # 16x16 / 512->512
                (UpSampling2D, 16, 512),  # 16x16 -> 32x32
                (Conv2DReflect, 512),  # 32x32 / 512->512
                (Conv2DReflect, 512),  # 32x32 / 512->512
                (Conv2DReflect, 512)
            ],  # 32x32 / 512->512
            4: [
                (Conv2DReflect, 256),  # 32x32 / 512->256
                (UpSampling2D, 32, 256),  # 32x32 -> 64x64
                (Conv2DReflect, 256),  # 64x64 / 256->256
                (Conv2DReflect, 256),  # 64x64 / 256->256
                (Conv2DReflect, 256)
            ],  # 64x64 / 256->256
            3: [
                (Conv2DReflect, 128),  # 64x64 / 256->128
                (UpSampling2D, 64, 128),  # 64x64 -> 128x128
                (Conv2DReflect, 128)
            ],  # 128x128 / 128->128
            2: [
                (Conv2DReflect, 64),  # 128x128 / 128->64
                (UpSampling2D, 128, 64)
            ],  # 128x128 -> 256x256
            1: [(Conv2DReflect, 64)]  # 256x256 / 64->64
        }

        # Former ways of creating the input; `code` is now supplied by the caller.
        #code = Input(shape=input_shape, name='decoder_input_'+relu_target)
        #code = tf.placeholder(shape=(None,) + input_shape, name='decoder_input_'+relu_target, dtype=tf.float32)
        x = code

        ### Work backwards from deepest decoder # and build layer by layer
        decoders = list(reversed(range(1, decoder_num + 1)))
        count = 0
        for d in decoders:
            for layer_tup in decoder_archs[d]:
                # Unique layer names are needed to ensure var naming consistency with multiple decoders in graph
                # (only the Conv2DReflect branch uses layer_name, but count advances for every entry)
                layer_name = '{}_{}'.format(relu_target, count)
                if layer_tup[0] == Conv2DReflect:
                    x = Conv2DReflect(x,
                                      layer_name,
                                      filters=layer_tup[1],
                                      kernel_size=(3, 3),
                                      padding='valid',
                                      activation=tf.nn.relu)
                elif layer_tup[0] == UpSampling2D:
                    # `d in []` is always False, so the plain-resize branch below is
                    # dead code; the commented-out lists were earlier experiments that
                    # enabled it for some levels.
                    #if d in [5, 4, 3, 2]:
                    #if d in [3, 2]:
                    if d in []:
                        hw = layer_tup[1]
                        x = tf.image.resize_images(x, size=(hw * 2, hw * 2))
                    else:
                        if use_wavelet_pooling:
                            hw = layer_tup[1]
                            c = layer_tup[2]

                            # Select the single encoder tensor whose channel count matches x.
                            indice_list = list(
                                filter(
                                    lambda ind: int(ind.get_shape()[-1]) ==
                                    int(x.get_shape()[-1]), encoder_indices))
                            assert len(indice_list) == 1
                            indice = indice_list[0]
                            indice = tf.identity(indice,
                                                 name="indice_{}_{}_{}".format(
                                                     hw, hw, c))
                            # c is immediately replaced by the static channel count
                            # (used for the reshape and conv filters below).
                            h, w, c = (tf.shape(x)[1], tf.shape(x)[2],
                                       tf.shape(x)[3])
                            c = int(x.get_shape()[-1])

                            #### [batch, h, w, c]
                            LL = x
                            #### [4, batch, h, w, c]
                            # NOTE(review): assumes `indice` is [3, batch, h, w, c]
                            # (the remaining sub-bands) so the concat yields 4 — confirm.
                            LLHH = tf.concat([tf.expand_dims(LL, 0), indice],
                                             axis=0)

                            #### [batch, c, 4, h, w] -> [batch * c, 4, h, w]
                            flatten_4hw = tf.reshape(
                                tf.transpose(LLHH, [1, 4, 0, 2, 3]),
                                [-1, 4, h, w])

                            #### [batch * c, h1, w1]
                            # The lambda's `x` shadows the outer `x`; it is one
                            # [4, h, w] slice of flatten_4hw.
                            # NOTE(review): tf.py_func has no registered gradient by
                            # default, so presumably nothing backpropagates through
                            # single_unpooling_func — confirm this is intended.
                            unpooling_x = tf.map_fn(
                                lambda x: tf.py_func(single_unpooling_func,
                                                     inp=[x],
                                                     Tout=tf.float32),
                                flatten_4hw,
                                dtype=tf.float32)

                            # Back to [batch, h1, w1, c].
                            h1, w1 = tf.shape(unpooling_x)[1], tf.shape(
                                unpooling_x)[2]
                            x = tf.transpose(
                                tf.reshape(unpooling_x, [-1, c, h1, w1]),
                                [0, 2, 3, 1])

                            # Fuse the unpooled result with a resized copy of the
                            # input via an auto-named 3x3 conv (no activation passed).
                            x_pre = tf.image.resize_images(LL, size=(h1, w1))
                            x = tf.layers.conv2d(tf.concat([
                                x,
                                x_pre,
                            ],
                                                           axis=-1),
                                                 filters=int(
                                                     int(x.get_shape()[-1])),
                                                 kernel_size=(3, 3),
                                                 padding="SAME")
                        else:
                            hw = layer_tup[1]
                            c = layer_tup[2]
                            # Select the single encoder tensor whose channel count matches x.
                            indice_list = list(
                                filter(
                                    lambda ind: int(ind.get_shape()[-1]) ==
                                    int(x.get_shape()[-1]), encoder_indices))
                            assert len(indice_list) == 1
                            indice = indice_list[0]
                            indice = tf.identity(indice,
                                                 name="indice_{}_{}_{}".format(
                                                     hw, hw, c))
                            h, w = (tf.shape(x)[1], tf.shape(x)[2])
                            x_pre = tf.image.resize_images(x,
                                                           size=(h * 2, w * 2))
                            # NOTE(review): `unpooling` (defined elsewhere) receives the
                            # already-resized x_pre plus the pre-resize h/w — confirm.
                            x = partial(unpooling, h=h, w=w,
                                        c=c)([x_pre, indice])

                            # Fuse with the resized input via an auto-named 1x1 conv.
                            x = tf.layers.conv2d(tf.concat([x, x_pre],
                                                           axis=-1),
                                                 filters=int(
                                                     int(x.get_shape()[-1])),
                                                 kernel_size=(1, 1),
                                                 padding="SAME")

                count += 1

        layer_name = '{}_{}'.format(relu_target, count)

        # Final projection to 3 channels with no activation.
        output = Conv2DReflect(x,
                               layer_name,
                               filters=3,
                               kernel_size=(3, 3),
                               padding='valid',
                               activation=None)

        return (code, output)