Exemplo n.º 1
0
def ConvLSTMNet(input_feature, isTraining):
    """Stack two ConvLSTM layers over `input_feature` and split the result.

    Two `BasicConvLSTMCell`s (64 output channels, then 1 output channel with
    `last_activation=None`) are unrolled one after the other with
    `tf.nn.dynamic_rnn(..., time_major=True)`. The output of the second layer
    is split along axis 1 into `len_seq - 1` pieces and axis 1 is squeezed
    out of every piece.

    Args:
        input_feature: tensor fed to the first recurrent layer.
        isTraining: forwarded to every cell as `is_training`.

    Returns:
        A Python list with `len_seq - 1` tensors.

    NOTE(review): `len_seq` is not defined in this function; it has to be a
    module-level name -- confirm against the rest of the file.
    """

    def _make_cell(num_filters, **extra):
        # Every cell shares the 32x32 grid, the 3x3 kernel and
        # normalize=False; only the filter count and the optional
        # `last_activation` differ. Fresh list literals are built per call.
        return BasicConvLSTMCell([32, 32],
                                 num_filters, [3, 3],
                                 normalize=False,
                                 is_training=isTraining,
                                 **extra)

    first_cell = _make_cell(64)
    second_cell = _make_cell(1, last_activation=None)
    # NOTE(review): this third cell is constructed but never connected to
    # the graph below. It is kept only so that construction-time behaviour
    # stays exactly as before -- confirm whether it can be dropped.
    cell_3 = _make_cell(1, last_activation=tf.nn.tanh)

    first_out, _ = tf.nn.dynamic_rnn(first_cell,
                                     input_feature,
                                     initial_state=None,
                                     dtype=tf.float32,
                                     time_major=True,
                                     scope='cell_1')
    second_out, _ = tf.nn.dynamic_rnn(second_cell,
                                      first_out,
                                      initial_state=None,
                                      dtype=tf.float32,
                                      time_major=True,
                                      scope='cell_2')

    print('LSTM shape:', second_out.shape)

    pieces = tf.split(second_out, num_or_size_splits=len_seq - 1, axis=1)
    return [tf.squeeze(piece, axis=1) for piece in pieces]
Exemplo n.º 2
0
    def generator(self, inputs, reuse=False, scope='g_net'):
        """Build the multi-level encoder/decoder generator graph.

        For each of `self.n_levels` levels the original `inputs` are resized
        to `self.scale ** (n_levels - i - 1)` of their size (the last level
        therefore runs at the input resolution), concatenated channel-wise
        with the gradient-stopped prediction of the previous level, and pushed
        through a three-stage encoder, an optional ConvLSTM bottleneck (only
        when `self.args.model == 'lstm'`) and a mirrored decoder with additive
        skip connections. Variables are created on the first level and reused
        on all later levels.

        Args:
            inputs: 4-D tensor whose static shape is unpacked as (n, h, w, c).
            reuse: passed to `tf.variable_scope(scope, reuse=reuse)`.
            scope: name of the variable scope that holds the conv weights.

        Returns:
            A list with one prediction tensor per level, in level order.
        """
        _, h, w, _ = inputs.get_shape().as_list()
        use_lstm = self.args.model == 'lstm'

        if use_lstm:
            with tf.variable_scope('LSTM'):
                # Floor division keeps the cell's spatial size integral on
                # Python 3 as well; it matches the `hi // 4` used below when
                # the state is resized per level.
                cell = BasicConvLSTMCell([h // 4, w // 4], [3, 3], 128)
                rnn_state = cell.zero_state(batch_size=self.batch_size,
                                            dtype=tf.float32)

        x_unwrap = []
        with tf.variable_scope(scope, reuse=reuse):
            with slim.arg_scope(
                [slim.conv2d, slim.conv2d_transpose],
                    activation_fn=tf.nn.relu,
                    padding='SAME',
                    normalizer_fn=None,
                    weights_initializer=tf.contrib.layers.xavier_initializer(
                        uniform=True),
                    biases_initializer=tf.constant_initializer(0.0)):

                inp_pred = inputs
                # `range` works on both Python 2 and 3 (`xrange` is 2-only).
                for i in range(self.n_levels):
                    scale = self.scale**(self.n_levels - i - 1)
                    hi = int(round(h * scale))
                    wi = int(round(w * scale))
                    inp_blur = tf.image.resize_images(inputs, [hi, wi],
                                                      method=0)
                    # The previous level's prediction is an input only; no
                    # gradient flows back through it.
                    inp_pred = tf.stop_gradient(
                        tf.image.resize_images(inp_pred, [hi, wi], method=0))
                    inp_all = tf.concat([inp_blur, inp_pred],
                                        axis=3,
                                        name='inp')
                    if use_lstm:
                        # Two stride-2 convs put the bottleneck at 1/4 of
                        # the level resolution; resize the carried state to
                        # match before feeding the cell.
                        rnn_state = tf.image.resize_images(rnn_state,
                                                           [hi // 4, wi // 4],
                                                           method=0)

                    # encoder
                    conv1_1 = slim.conv2d(inp_all, 32, [5, 5], scope='enc1_1')
                    conv1_2 = ResnetBlock(conv1_1, 32, 5, scope='enc1_2')
                    conv1_3 = ResnetBlock(conv1_2, 32, 5, scope='enc1_3')
                    conv1_4 = ResnetBlock(conv1_3, 32, 5, scope='enc1_4')
                    conv2_1 = slim.conv2d(conv1_4,
                                          64, [5, 5],
                                          stride=2,
                                          scope='enc2_1')
                    conv2_2 = ResnetBlock(conv2_1, 64, 5, scope='enc2_2')
                    conv2_3 = ResnetBlock(conv2_2, 64, 5, scope='enc2_3')
                    conv2_4 = ResnetBlock(conv2_3, 64, 5, scope='enc2_4')
                    conv3_1 = slim.conv2d(conv2_4,
                                          128, [5, 5],
                                          stride=2,
                                          scope='enc3_1')
                    conv3_2 = ResnetBlock(conv3_1, 128, 5, scope='enc3_2')
                    conv3_3 = ResnetBlock(conv3_2, 128, 5, scope='enc3_3')
                    conv3_4 = ResnetBlock(conv3_3, 128, 5, scope='enc3_4')

                    # bottleneck: recurrent across levels when enabled
                    if use_lstm:
                        deconv3_4, rnn_state = cell(conv3_4, rnn_state)
                    else:
                        deconv3_4 = conv3_4

                    # decoder
                    deconv3_3 = ResnetBlock(deconv3_4, 128, 5, scope='dec3_3')
                    deconv3_2 = ResnetBlock(deconv3_3, 128, 5, scope='dec3_2')
                    deconv3_1 = ResnetBlock(deconv3_2, 128, 5, scope='dec3_1')
                    deconv2_4 = slim.conv2d_transpose(deconv3_1,
                                                      64, [4, 4],
                                                      stride=2,
                                                      scope='dec2_4')
                    # additive skip connection from the matching encoder stage
                    cat2 = deconv2_4 + conv2_4
                    deconv2_3 = ResnetBlock(cat2, 64, 5, scope='dec2_3')
                    deconv2_2 = ResnetBlock(deconv2_3, 64, 5, scope='dec2_2')
                    deconv2_1 = ResnetBlock(deconv2_2, 64, 5, scope='dec2_1')
                    deconv1_4 = slim.conv2d_transpose(deconv2_1,
                                                      32, [4, 4],
                                                      stride=2,
                                                      scope='dec1_4')
                    cat1 = deconv1_4 + conv1_4
                    deconv1_3 = ResnetBlock(cat1, 32, 5, scope='dec1_3')
                    deconv1_2 = ResnetBlock(deconv1_3, 32, 5, scope='dec1_2')
                    deconv1_1 = ResnetBlock(deconv1_2, 32, 5, scope='dec1_1')
                    # linear output layer producing `self.chns` channels
                    inp_pred = slim.conv2d(deconv1_1,
                                           self.chns, [5, 5],
                                           activation_fn=None,
                                           scope='dec1_0')

                    x_unwrap.append(inp_pred)
                    if i == 0:
                        # Weights now exist; every later level shares them.
                        tf.get_variable_scope().reuse_variables()

            return x_unwrap
Exemplo n.º 3
0
def generator(inputs, scope='g_net', n_levels=2):
    """Build a multi-level bottleneck encoder/decoder with a ConvLSTM core.

    For each of `n_levels` levels the original `inputs` are resized to
    `0.5 ** (n_levels - i - 1)` of their size (the last level runs at the
    input resolution), concatenated channel-wise with the gradient-stopped
    prediction of the previous level, and passed through an encoder with
    three stride-2 stages, a ConvLSTM cell whose state is carried across
    levels, and a decoder with additive skip connections. All variables live
    under `scope`, opened with `reuse=tf.AUTO_REUSE`, so levels share weights.

    Args:
        inputs: 4-D tensor whose static shape is unpacked as (n, h, w, c);
            `n` is used as the ConvLSTM batch size and `c` as the number of
            output channels.
        scope: name of the variable scope that holds the weights.
        n_levels: number of levels to unroll.

    Returns:
        A list with one prediction tensor per level, in level order.
    """
    n, h, w, c = inputs.get_shape().as_list()

    x_unwrap = []
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        with slim.arg_scope(
            [slim.conv2d, slim.conv2d_transpose, slim.separable_conv2d],
                activation_fn=parametric_relu,
                padding='SAME',
                normalizer_fn=None,
                weights_initializer=tf.contrib.layers.xavier_initializer(
                    uniform=True),
                biases_initializer=tf.constant_initializer(0.0)):

            # Floor division keeps the cell's spatial size integral on
            # Python 3 (plain `/` would yield floats); it matches the
            # `hi // 8` used below when the state is resized per level.
            cell = BasicConvLSTMCell([h // 8, w // 8], [1, 1], 64)
            rnn_state = cell.zero_state(batch_size=n, dtype=tf.float32)
            inp_pred = inputs
            for i in range(n_levels):
                scale = 0.5**(n_levels - i - 1)
                hi = int(round(h * scale))
                wi = int(round(w * scale))

                inp_blur = tf.image.resize_images(inputs, [hi, wi], method=0)
                inp_pred = tf.image.resize_images(inp_pred, [hi, wi], method=0)
                # The previous level's prediction is an input only; no
                # gradient flows back through it.
                inp_pred = tf.stop_gradient(inp_pred)
                inp_all = tf.concat([inp_blur, inp_pred], axis=3, name='inp')
                # Three stride-2 stages put the bottleneck at 1/8 of the
                # level resolution; resize the carried state to match.
                rnn_state = tf.image.resize_images(rnn_state,
                                                   [hi // 8, wi // 8],
                                                   method=0)

                # Graph-construction-time trace of the per-level input.
                print(inp_all)

                # encoder
                conv0 = slim.conv2d(inp_all, 8, [5, 5], scope='enc0')
                net = slim.conv2d(conv0, 16, [5, 5], stride=2, scope='enc1_1')
                conv1 = ResBottleneckBlock(net, 16, 5, scope='enc1_2')
                net = res_bottleneck_dsconv(conv1,
                                            32,
                                            5,
                                            stride=2,
                                            scope='enc2_1')
                net = ResBottleneckBlock(net, 32, 5, scope='enc2_2')
                net = ResBottleneckBlock(net, 32, 5, scope='enc2_3')
                conv2 = ResBottleneckBlock(net, 32, 5, scope='enc2_4')
                net = res_bottleneck_dsconv(conv2,
                                            64,
                                            5,
                                            stride=2,
                                            scope='enc3_1')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_2')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_3')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_4')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_5')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_6')

                # recurrent bottleneck shared across levels
                net, rnn_state = cell(net, rnn_state)

                # decoder
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_6')
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_5')
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_4')
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_3')
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_2')
                net = slim.conv2d_transpose(net,
                                            32, [5, 5],
                                            stride=2,
                                            scope='dec3_1')
                # additive skip connections from the matching encoder stages
                net = net + conv2
                net = ResBottleneckBlock(net, 32, 5, scope='dec2_4')
                net = ResBottleneckBlock(net, 32, 5, scope='dec2_3')
                net = ResBottleneckBlock(net, 32, 5, scope='dec2_2')
                net = slim.conv2d_transpose(net,
                                            16, [5, 5],
                                            stride=2,
                                            scope='dec2_1')
                net = net + conv1
                net = ResBottleneckBlock(net, 16, 5, scope='dec1_2')
                net = slim.conv2d_transpose(net,
                                            8, [5, 5],
                                            stride=2,
                                            scope='dec1_1')
                net = net + conv0
                # linear output layer with as many channels as the input
                inp_pred = slim.conv2d(net,
                                       c, [5, 5],
                                       activation_fn=None,
                                       scope='dec0')

                x_unwrap.append(inp_pred)
        return x_unwrap
Exemplo n.º 4
0
    def generator(self, inputs, reuse=False):
        n, h, w, c = inputs.get_shape().as_list()
        n_feat = self.n_feat
        kernel_size = self.kernel_size
        scope = self.model

        x_unwrap = []
        if self.args.model == 'lstm':
            with tf.variable_scope('LSTM'):
                cell = BasicConvLSTMCell([h / 4, w / 4], [3, 3], 128)
                rnn_state = cell.zero_state(batch_size=self.batch_size,
                                            dtype=tf.float32)

        with tf.variable_scope(scope, reuse=reuse):
            with slim.arg_scope(
                [slim.conv2d, slim.conv2d_transpose],
                    activation_fn=tf.nn.relu,
                    padding='SAME',
                    normalizer_fn=None,
                    weights_initializer=tf.contrib.layers.xavier_initializer(
                        uniform=True),
                    biases_initializer=tf.constant_initializer(0.0)):

                inp_pred = inputs
                if self.model == 'SRN':
                    #x_unwrap = []
                    for i in xrange(self.n_levels):
                        scale = self.scale**(self.n_levels - i - 1)
                        hi = int(round(h * scale))
                        wi = int(round(w * scale))
                        inp_blur = tf.image.resize_images(inputs, [hi, wi],
                                                          method=0)
                        inp_pred = tf.stop_gradient(
                            tf.image.resize_images(inp_pred, [hi, wi],
                                                   method=0))
                        inp_all = tf.concat([inp_blur, inp_pred],
                                            axis=3,
                                            name='inp')
                        if self.args.model == 'lstm':
                            rnn_state = tf.image.resize_images(
                                rnn_state, [hi // 4, wi // 4], method=0)

                        eb1 = InBlock(inp_all,
                                      n_feat,
                                      kernel_size,
                                      num_resb=self.num_resb,
                                      scope='InBlock')

                        eb2 = EBlock(eb1,
                                     n_feat * 2,
                                     kernel_size,
                                     num_resb=self.num_resb,
                                     scope='eb2')

                        eb3 = EBlock(eb2,
                                     n_feat * 4,
                                     kernel_size,
                                     num_resb=self.num_resb,
                                     scope='eb3')

                        if self.args.model == 'lstm':
                            #deconv3_4, rnn_state = cell(conv3_4, rnn_state)
                            deconv3_4, rnn_state = cell(eb3, rnn_state)
                        else:
                            #deconv3_4 = conv3_4
                            deconv3_4 = eb3

                        db1 = DBlock(eb3, n_feat * 4, kernel_size, scope='db1')
                        cat2 = db1 + eb2

                        db2 = DBlock(cat2,
                                     n_feat * 2,
                                     kernel_size,
                                     scope='db2')
                        cat1 = db2 + eb1

                        inp_pred = OutBlock(cat1, n_feat, kernel_size)

                        if i >= 0:
                            x_unwrap.append(inp_pred)
                        if i == 0:
                            tf.get_variable_scope().reuse_variables()

                    return x_unwrap
                elif self.model == 'raw':
                    inp_pred = inputs
                    #x_unwrap = []
                    for i in xrange(self.n_levels):
                        scale = self.scale**(self.n_levels - i - 1)
                        hi = int(round(h * scale))
                        wi = int(round(w * scale))
                        inp_blur = tf.image.resize_images(inputs, [hi, wi],
                                                          method=0)
                        inp_pred = tf.stop_gradient(
                            tf.image.resize_images(inp_pred, [hi, wi],
                                                   method=0))
                        inp_all = tf.concat([inp_blur, inp_pred],
                                            axis=3,
                                            name='inp')
                        if self.args.model == 'lstm':
                            rnn_state = tf.image.resize_images(
                                rnn_state, [hi // 4, wi // 4], method=0)

                        # encoder
                        conv1_1 = slim.conv2d(inp_all,
                                              32, [5, 5],
                                              scope='enc1_1')
                        conv1_2 = ResnetBlock(conv1_1, 32, 5, scope='enc1_2')
                        conv1_3 = ResnetBlock(conv1_2, 32, 5, scope='enc1_3')
                        conv1_4 = ResnetBlock(conv1_3, 32, 5, scope='enc1_4')
                        conv2_1 = slim.conv2d(conv1_4,
                                              64, [5, 5],
                                              stride=2,
                                              scope='enc2_1')
                        conv2_2 = ResnetBlock(conv2_1, 64, 5, scope='enc2_2')
                        conv2_3 = ResnetBlock(conv2_2, 64, 5, scope='enc2_3')
                        conv2_4 = ResnetBlock(conv2_3, 64, 5, scope='enc2_4')
                        conv3_1 = slim.conv2d(conv2_4,
                                              128, [5, 5],
                                              stride=2,
                                              scope='enc3_1')
                        conv3_2 = ResnetBlock(conv3_1, 128, 5, scope='enc3_2')
                        conv3_3 = ResnetBlock(conv3_2, 128, 5, scope='enc3_3')
                        conv3_4 = ResnetBlock(conv3_3, 128, 5, scope='enc3_4')

                        if self.args.model == 'lstm':
                            deconv3_4, rnn_state = cell(conv3_4, rnn_state)
                        else:
                            deconv3_4 = conv3_4

                        # decoder
                        deconv3_3 = ResnetBlock(deconv3_4,
                                                128,
                                                5,
                                                scope='dec3_3')
                        deconv3_2 = ResnetBlock(deconv3_3,
                                                128,
                                                5,
                                                scope='dec3_2')
                        deconv3_1 = ResnetBlock(deconv3_2,
                                                128,
                                                5,
                                                scope='dec3_1')
                        deconv2_4 = slim.conv2d_transpose(deconv3_1,
                                                          64, [4, 4],
                                                          stride=2,
                                                          scope='dec2_4')
                        cat2 = deconv2_4 + conv2_4
                        deconv2_3 = ResnetBlock(cat2, 64, 5, scope='dec2_3')
                        deconv2_2 = ResnetBlock(deconv2_3,
                                                64,
                                                5,
                                                scope='dec2_2')
                        deconv2_1 = ResnetBlock(deconv2_2,
                                                64,
                                                5,
                                                scope='dec2_1')
                        deconv1_4 = slim.conv2d_transpose(deconv2_1,
                                                          32, [4, 4],
                                                          stride=2,
                                                          scope='dec1_4')
                        cat1 = deconv1_4 + conv1_4
                        deconv1_3 = ResnetBlock(cat1, 32, 5, scope='dec1_3')
                        deconv1_2 = ResnetBlock(deconv1_3,
                                                32,
                                                5,
                                                scope='dec1_2')
                        deconv1_1 = ResnetBlock(deconv1_2,
                                                32,
                                                5,
                                                scope='dec1_1')
                        inp_pred = slim.conv2d(deconv1_1,
                                               self.chns, [5, 5],
                                               activation_fn=None,
                                               scope='dec1_0')

                        if i >= 0:
                            x_unwrap.append(inp_pred)
                        if i == 0:
                            tf.get_variable_scope().reuse_variables()

                    return x_unwrap

                elif self.model == 'DAVANet':
                    #x_unwrap = []
                    conv1_1 = Conv(inputs, 8, ksize=[3, 3], scope='conv1_1')
                    conv1_2 = resnet_block(conv1_1,
                                           8,
                                           ksize=3,
                                           scope='conv1_2')

                    #downsample
                    conv2_1 = Conv(conv1_2,
                                   16,
                                   ksize=[3, 3],
                                   stride=2,
                                   scope='conv2_1')
                    conv2_2 = resnet_block(conv2_1,
                                           16,
                                           ksize=3,
                                           scope='conv2_2')

                    #downsample
                    conv3_1 = Conv(conv2_2,
                                   32,
                                   ksize=[3, 3],
                                   stride=2,
                                   scope='conv3_1')
                    conv3_2 = resnet_block(conv3_1,
                                           32,
                                           ksize=3,
                                           scope='conv3_2')

                    conv4_1 = Conv(conv3_2)

                    dilation = [1, 2, 3, 4]
                    convd_1 = resnet_block(conv3_2,
                                           32,
                                           ksize=3,
                                           dilation=[2, 1],
                                           scope='convd_1')
                    convd_2 = resnet_block(convd_1,
                                           32,
                                           ksize=3,
                                           dilation=[3, 1],
                                           scope='convd_2')
                    convd_3 = ms_dilate_block(convd_2,
                                              32,
                                              dilation=dilation,
                                              scope='convd_3')

                    #decode
                    upconv3_2 = Conv(convd_3,
                                     32,
                                     ksize=[3, 3],
                                     scope='upconv3_4')
                    upconv3_1 = resnet_block(upconv3_2,
                                             32,
                                             ksize=3,
                                             scope='upconv3_3')

                    #upsample
                    upconv2_u = upconv(upconv3_1, 16, scope='upconv2_u')
                    cat1 = tf.concat((upconv2_u, conv2_2), axis=3)
                    upconv2_4 = Conv(cat1, 16, ksize=[3, 3], scope='upconv2_4')
                    upconv2_3 = resnet_block(upconv2_4,
                                             16,
                                             ksize=3,
                                             scope='upconv2_3')

                    #upsample
                    upconv1_u = upconv(upconv2_3, 8, scope='upconv1_u')
                    cat0 = tf.concat((upconv1_u, conv1_2), axis=3)
                    upconv1_2 = Conv(cat0, 8, ksize=[3, 3], scope='upconv1_2')
                    upconv1_1 = resnet_block(upconv1_2,
                                             8,
                                             ksize=3,
                                             scope='upconv1_1')

                    inp_pred = Conv(upconv1_1, 3, ksize=[3, 3], scope='output')

                    return x_unwrap.append(inp_pred +
                                           inputs)  #inp_pred + inputs

                elif self.model == 'unet':

                    conv1_1 = slim.conv2d(inputs,
                                          8, [kernel_size, kernel_size],
                                          scope='enc1_1')
                    #conv1_4 = ResnetBlock(conv1_1, 8, kernel_size, scope='enc1_4')
                    conv1_4 = InvertedResidualBlock(conv1_1,
                                                    8,
                                                    expansion=2,
                                                    scope='enc1_4')

                    #conv2_1 = slim.conv2d(conv1_4, 16, [kernel_size, kernel_size], stride=2, scope='enc2_1')
                    conv2_1 = DepthwiseSeparableConvBlock(conv1_4,
                                                          16,
                                                          stride=2,
                                                          scope='enc2_1')
                    #conv2_4 = ResnetBlock(conv2_1, 16, kernel_size, scope='enc2_4')
                    conv2_4 = InvertedResidualBlock(conv2_1,
                                                    16,
                                                    expansion=2,
                                                    scope='enc2_4')

                    #conv3_1 = slim.conv2d(conv2_4, 32, [kernel_size, kernel_size], stride=2, scope='enc3_1')
                    conv3_1 = DepthwiseSeparableConvBlock(conv2_4,
                                                          32,
                                                          stride=2,
                                                          scope='enc3_1')
                    #conv3_4 = ResnetBlock(conv3_1, 32, kernel_size, scope='enc3_4')
                    conv3_4 = InvertedResidualBlock(conv3_1,
                                                    32,
                                                    expansion=4,
                                                    scope='enc3_4')

                    #conv4_1 = slim.conv2d(conv3_4, 48, [kernel_size, kernel_size], stride=2, scope='conv4_1')
                    conv4_1 = DepthwiseSeparableConvBlock(conv3_4,
                                                          48,
                                                          stride=2,
                                                          scope='enc4_1')
                    #conv4_4 = ResnetBlock(conv4_1, 48, kernel_size, scope='conv4_4')
                    conv4_4 = InvertedResidualBlock(conv4_1,
                                                    48,
                                                    expansion=4,
                                                    scope='enc4_4')

                    #conv5_1 = slim.conv2d(conv4_4, 64, [kernel_size, kernel_size], stride=2, scope='conv5_1')
                    #conv5_4 = ResnetBlock(conv5_1, 64, kernel_size, scope='conv5_4')
                    conv5_1 = DepthwiseSeparableConvBlock(conv4_4,
                                                          64,
                                                          stride=2,
                                                          scope='enc5_1')
                    conv5_4 = InvertedResidualBlock(conv5_1,
                                                    64,
                                                    expansion=4,
                                                    scope='enc5_4')

                    deconv5_4 = conv5_4
                    #                # decoder
                    #deconv5_3 = InvertedResidualBlock(deconv5_4, 64, expansion=4, scope='deconv5_3')
                    deconv5_0 = slim.conv2d_transpose(deconv5_4,
                                                      48, [4, 4],
                                                      stride=2,
                                                      scope='deconv5_0')
                    cat4 = deconv5_0 + conv4_4
                    deconv4_3 = InvertedResidualBlock(cat4,
                                                      48,
                                                      expansion=4,
                                                      scope='deconv4_3')
                    deconv4_0 = slim.conv2d_transpose(deconv4_3,
                                                      32, [4, 4],
                                                      stride=2,
                                                      scope='deconv4_0')
                    cat3 = deconv4_0 + conv3_4
                    deconv3_3 = InvertedResidualBlock(cat3,
                                                      32,
                                                      expansion=4,
                                                      scope='deconv3_3')
                    deconv3_0 = slim.conv2d_transpose(deconv3_3,
                                                      16, [4, 4],
                                                      stride=2,
                                                      scope='deconv3_0')
                    cat2 = deconv3_0 + conv2_4
                    deconv2_3 = InvertedResidualBlock(cat2,
                                                      16,
                                                      expansion=2,
                                                      scope='deconv2_3')
                    deconv2_0 = slim.conv2d_transpose(deconv2_3,
                                                      8, [4, 4],
                                                      stride=2,
                                                      scope='deconv2_0')
                    cat1 = deconv2_0 + conv1_4
                    deconv1_3 = InvertedResidualBlock(cat1,
                                                      8,
                                                      expansion=2,
                                                      scope='dec1_3')
                    inp_pred = slim.conv2d(deconv1_3,
                                           3, [kernel_size, kernel_size],
                                           activation_fn=slim.nn.sigmoid,
                                           scope='output')
                    return x_unwrap.append(inp_pred)

                elif self.model == 'DMPHN':
                    #x_unwrap = []
                    net = slim.conv2d(inputs,
                                      n_feat, [3, 3],
                                      activation_fn=None,
                                      scope='ec_conv1')
                    net = ResidualLinkBlock(net,
                                            n_feat,
                                            ksize=3,
                                            scope='ec_rlb1')
                    net = ResidualLinkBlock(net,
                                            n_feat,
                                            ksize=3,
                                            scope='ec_rlb2')

                    net = slim.conv2d(net,
                                      n_feat * 2, [3, 3],
                                      stride=2,
                                      activation_fn=None,
                                      scope='ec_conv2')
                    net = ResidualLinkBlock(net,
                                            n_feat * 2,
                                            ksize=3,
                                            scope='ec_rlb3')
                    net = ResidualLinkBlock(net,
                                            n_feat * 2,
                                            ksize=3,
                                            scope='ec_rlb4')

                    net = slim.conv2d(net,
                                      n_feat * 4, [3, 3],
                                      stride=2,
                                      activation_fn=None,
                                      scope='ec_conv3')
                    net = ResidualLinkBlock(net,
                                            n_feat * 4,
                                            ksize=3,
                                            scope='ec_rlb5')
                    net = ResidualLinkBlock(net,
                                            n_feat * 4,
                                            ksize=3,
                                            scope='ec_rlb6')

                    net = ResidualLinkBlock(net,
                                            n_feat * 4,
                                            ksize=3,
                                            scope='dc_rlb1')
                    net = ResidualLinkBlock(net,
                                            n_feat * 4,
                                            ksize=3,
                                            scope='dc_rlb2')

                    net = slim.conv2d_transpose(net,
                                                n_feat * 2, [4, 4],
                                                stride=2,
                                                activation_fn=None,
                                                scope='dc_deconv1')
                    net = ResidualLinkBlock(net,
                                            n_feat * 2,
                                            ksize=3,
                                            scope='dc_rlb3')
                    net = ResidualLinkBlock(net,
                                            n_feat * 2,
                                            ksize=3,
                                            scope='dc_flb4')

                    net = slim.conv2d_transpose(net,
                                                n_feat, [4, 4],
                                                stride=2,
                                                activation_fn=None,
                                                scope='dc_deconv2')
                    net = ResidualLinkBlock(net,
                                            n_feat,
                                            ksize=3,
                                            scope='dc_rlb5')
                    net = ResidualLinkBlock(net,
                                            n_feat,
                                            ksize=3,
                                            scope='dc_flb6')

                    net = slim.conv2d(net,
                                      3, [3, 3],
                                      activation_fn=None,
                                      scope='dc_conv1')

                    return x_unwrap.append(net)  #net

                elif self.model == 'DAVANet_light':
                    eb1 = InBlock(inputs,
                                  n_feat,
                                  kernel_size,
                                  num_resb=1,
                                  scope='InBlock')

                    eb2 = EBlock(eb1,
                                 n_feat * 2,
                                 kernel_size,
                                 num_resb=1,
                                 scope='eb1')

                    eb3 = EBlock(eb2,
                                 n_feat * 4,
                                 kernel_size,
                                 num_resb=1,
                                 scope='eb2')

                    context = ContextModule_lite(eb3, n_feat * 4)

                    db1 = DBlock(context,
                                 n_feat * 4,
                                 kernel_size,
                                 num_resb=1,
                                 scope='db1')
                    cat2 = db1 + eb2

                    db2 = DBlock(cat2,
                                 n_feat * 2,
                                 kernel_size,
                                 num_resb=1,
                                 scope='db2')
                    cat1 = db2 + eb1

                    inp_pred = OutBlock(cat1,
                                        n_feat,
                                        kernel_size,
                                        num_resb=1,
                                        scope='OutBlock')

                    return x_unwrap.append(inp_pred + inputs)

                elif self.model == 'DAVANet_dw':

                    eb1 = InBlock_dw(inputs,
                                     n_feat,
                                     num_resb=self.num_resb,
                                     expansion=2,
                                     scope='InBlock')

                    eb2 = EBlock_dw(eb1,
                                    n_feat * 2,
                                    num_resb=self.num_resb,
                                    expansion=4,
                                    scope='eb1')

                    eb3 = EBlock_dw(eb2,
                                    n_feat * 4,
                                    num_resb=self.num_resb,
                                    expansion=4,
                                    scope='eb2')

                    context = ContextModule_dwlite(eb3, n_feat * 4)

                    db1 = DBlock_dw(context,
                                    n_feat * 4,
                                    num_resb=self.num_resb,
                                    expansion=4,
                                    scope='db1')
                    cat2 = db1 + eb2

                    db2 = DBlock_dw(cat2,
                                    n_feat * 2,
                                    num_resb=self.num_resb,
                                    expansion=4,
                                    scope='db2')
                    cat1 = db2 + eb1

                    inp_pred = OutBlock_dw(cat1,
                                           n_feat,
                                           num_resb=self.num_resb,
                                           expansion=2,
                                           scope='OutBlock')

                    return x_unwrap.append(inp_pred + inputs)

                elif self.model == 'DFANet':
                    conv1 = slim.conv2d(inputs,
                                        8,
                                        kernel_size=[3, 3],
                                        stride=2,
                                        scope='conv1')

                elif self.model == 'Deblur_lite':
                    conv1 = slim.conv2d(inputs, 8, [3, 3], scope='conv1')
Exemplo n.º 5
0
    def generator(self,
                  inputs,
                  inputs_render,
                  coeff,
                  reuse=False,
                  scope='g_net'):
        """Build the multi-scale encoder/decoder generator graph.

        The same encoder/decoder is instantiated once per level for
        ``self.n_levels`` levels. Level ``i`` works at the resolution
        ``self.scale ** (self.n_levels - i - 1)`` of the input (so coarsest
        first when ``self.scale < 1``). Each level is fed the input resized
        to that resolution, concatenated on the channel axis with the
        gradient-stopped, resized prediction of the previous level. Variables
        are created at level 0 and reused by every later level.

        Args:
            inputs: rank-4 tensor whose static shape is unpacked as
                ``(n, h, w, c)``; ``h`` and ``w`` are used in Python
                arithmetic, so they must be statically known.
            inputs_render: tensor resized to each level's resolution and
                concatenated to the first encoder feature map when
                ``self.args.face`` is ``'render'`` or ``'both'``. The channel
                count is then hard-coded to 35 (32 + 3), so this presumably
                has 3 channels -- TODO confirm against the caller.
            coeff: rank-3 tensor that is zero-padded to the bottleneck's
                spatial size and appended as one extra channel when
                ``self.args.face`` is ``'coeff'`` or ``'both'``. Presumably
                shaped ``(batch, 9, 9)`` -- TODO confirm.
            reuse: forwarded to ``tf.variable_scope``.
            scope: name of the enclosing variable scope.

        Returns:
            A list with one prediction tensor per level (the ``self.chns``
            channel output of the final convolution), in level order.
        """
        # Unpacking four values also asserts that `inputs` is rank 4.
        _, h, w, _ = inputs.get_shape().as_list()

        use_lstm = self.args.model == 'lstm'
        use_render = self.args.face in ('render', 'both')
        use_coeff = self.args.face in ('coeff', 'both')

        if use_lstm:
            with tf.variable_scope('LSTM'):
                # Floor division keeps the cell shape integral on Python 3 and
                # matches the `hi // 4` used when the state is resized below.
                cell = BasicConvLSTMCell([h // 4, w // 4], [3, 3], 128)
                rnn_state = cell.zero_state(batch_size=self.batch_size,
                                            dtype=tf.float32)

        def pad_coeff(coeff_p, h_h, w_w):
            """Zero-pad `coeff_p` to (h_h, w_w) and add a float32 channel axis."""
            # 9 is presumably the spatial size of coeff_p -- TODO confirm.
            pad_h = h_h - 9
            pad_w = w_w - 9
            # Split the padding as evenly as possible between both sides.
            top = int(round(pad_h * 0.5))
            left = int(round(pad_w * 0.5))
            coeff_pm = tf.pad(
                coeff_p, [[0, 0], [top, pad_h - top], [left, pad_w - left]])
            coeff_pm = tf.expand_dims(coeff_pm, 3)
            return tf.cast(coeff_pm, tf.float32)

        x_unwrap = []
        with tf.variable_scope(scope, reuse=reuse):
            with slim.arg_scope(
                [slim.conv2d, slim.conv2d_transpose],
                    activation_fn=tf.nn.relu,
                    padding='SAME',
                    normalizer_fn=None,
                    weights_initializer=tf.contrib.layers.xavier_initializer(
                        uniform=True),
                    biases_initializer=tf.constant_initializer(0.0)):

                inp_pred = inputs
                # `range` behaves the same as `xrange` here and exists on both
                # Python 2 and Python 3.
                for i in range(self.n_levels):
                    scale = self.scale**(self.n_levels - i - 1)
                    hi = int(round(h * scale))
                    wi = int(round(w * scale))
                    inp_blur = tf.image.resize_images(inputs, [hi, wi],
                                                      method=0)
                    # The previous level's output is treated as a constant
                    # input: no gradient flows back across levels through it.
                    inp_pred = tf.stop_gradient(
                        tf.image.resize_images(inp_pred, [hi, wi], method=0))
                    inp_all = tf.concat([inp_blur, inp_pred],
                                        axis=3,
                                        name='inp')
                    if use_lstm:
                        # Keep the recurrent state aligned with this level's
                        # bottleneck resolution (two stride-2 convs => /4).
                        rnn_state = tf.image.resize_images(rnn_state,
                                                           [hi // 4, wi // 4],
                                                           method=0)

                    # encoder
                    conv1_1 = slim.conv2d(inp_all, 32, [5, 5], scope='enc1_1')

                    conv1_1_c_nums = 32
                    if use_render:
                        # Append the rendered input as extra feature channels.
                        inp_render = tf.image.resize_images(inputs_render,
                                                            [hi, wi],
                                                            method=0)
                        conv1_1 = tf.concat([conv1_1, inp_render], axis=3)
                        conv1_1_c_nums = 35

                    conv1_2 = ResnetBlock(conv1_1,
                                          conv1_1_c_nums,
                                          5,
                                          scope='enc1_2')
                    conv1_3 = ResnetBlock(conv1_2,
                                          conv1_1_c_nums,
                                          5,
                                          scope='enc1_3')
                    conv1_4 = ResnetBlock(conv1_3,
                                          conv1_1_c_nums,
                                          5,
                                          scope='enc1_4')
                    conv2_1 = slim.conv2d(conv1_4,
                                          64, [5, 5],
                                          stride=2,
                                          scope='enc2_1')
                    conv2_2 = ResnetBlock(conv2_1, 64, 5, scope='enc2_2')
                    conv2_3 = ResnetBlock(conv2_2, 64, 5, scope='enc2_3')
                    conv2_4 = ResnetBlock(conv2_3, 64, 5, scope='enc2_4')
                    conv3_1 = slim.conv2d(conv2_4,
                                          128, [5, 5],
                                          stride=2,
                                          scope='enc3_1')
                    conv3_2 = ResnetBlock(conv3_1, 128, 5, scope='enc3_2')
                    conv3_3 = ResnetBlock(conv3_2, 128, 5, scope='enc3_3')
                    conv3_4 = ResnetBlock(conv3_3, 128, 5, scope='enc3_4')

                    if use_lstm:
                        deconv3_4, rnn_state = cell(conv3_4, rnn_state)
                    else:
                        deconv3_4 = conv3_4

                    channel_nums = 128
                    if use_coeff:
                        # Append the padded coefficient map as one extra
                        # bottleneck channel (128 -> 129).
                        _, h_c, w_c, _ = deconv3_4.get_shape().as_list()
                        coeff_m = pad_coeff(coeff, h_c, w_c)
                        deconv3_4 = tf.concat([deconv3_4, coeff_m], axis=3)
                        channel_nums = 129

                    # decoder
                    deconv3_3 = ResnetBlock(deconv3_4,
                                            channel_nums,
                                            5,
                                            scope='dec3_3')
                    deconv3_2 = ResnetBlock(deconv3_3,
                                            channel_nums,
                                            5,
                                            scope='dec3_2')
                    deconv3_1 = ResnetBlock(deconv3_2,
                                            channel_nums,
                                            5,
                                            scope='dec3_1')
                    deconv2_4 = slim.conv2d_transpose(deconv3_1,
                                                      64, [4, 4],
                                                      stride=2,
                                                      scope='dec2_4')
                    # Additive skip connection from the matching encoder stage.
                    cat2 = deconv2_4 + conv2_4
                    deconv2_3 = ResnetBlock(cat2, 64, 5, scope='dec2_3')
                    deconv2_2 = ResnetBlock(deconv2_3, 64, 5, scope='dec2_2')
                    deconv2_1 = ResnetBlock(deconv2_2, 64, 5, scope='dec2_1')
                    deconv1_4 = slim.conv2d_transpose(deconv2_1,
                                                      conv1_1_c_nums, [4, 4],
                                                      stride=2,
                                                      scope='dec1_4')
                    cat1 = deconv1_4 + conv1_4
                    deconv1_3 = ResnetBlock(cat1,
                                            conv1_1_c_nums,
                                            5,
                                            scope='dec1_3')
                    deconv1_2 = ResnetBlock(deconv1_3,
                                            conv1_1_c_nums,
                                            5,
                                            scope='dec1_2')
                    deconv1_1 = ResnetBlock(deconv1_2,
                                            conv1_1_c_nums,
                                            5,
                                            scope='dec1_1')
                    inp_pred = slim.conv2d(deconv1_1,
                                           self.chns, [5, 5],
                                           activation_fn=None,
                                           scope='dec1_0')

                    # Every level contributes an output.
                    x_unwrap.append(inp_pred)
                    if i == 0:
                        # All variables now exist; later levels share them.
                        tf.get_variable_scope().reuse_variables()

                    # Tile the prediction self.n_frames times along the channel
                    # axis so it can be fed back as the next level's input.
                    inp_pred_temp = inp_pred
                    for _ in range(1, self.n_frames):
                        inp_pred = tf.concat([inp_pred, inp_pred_temp], axis=3)

            return x_unwrap