def ConvLSTMNet(input_feature, isTraining):
    """Stacks two ConvLSTM cells over a feature sequence and returns the
    per-step depth predictions as a list of [N, H, W, C] tensors."""
    cell_1 = BasicConvLSTMCell([32, 32], 64, [3, 3],
                               normalize=False, is_training=isTraining)
    cell_2 = BasicConvLSTMCell([32, 32], 1, [3, 3], last_activation=None,
                               normalize=False, is_training=isTraining)
    cell_3 = BasicConvLSTMCell([32, 32], 1, [3, 3], last_activation=tf.nn.tanh,
                               normalize=False, is_training=isTraining)  # defined but unused

    outputs1, state1 = tf.nn.dynamic_rnn(cell_1, input_feature,
                                         initial_state=None, dtype=tf.float32,
                                         time_major=True, scope='cell_1')
    outputs2, state2 = tf.nn.dynamic_rnn(cell_2, outputs1,
                                         initial_state=None, dtype=tf.float32,
                                         time_major=True, scope='cell_2')
    # outputs2 = tf.Print(outputs2, [outputs2])
    print('LSTM shape:', outputs2.shape)

    # `len_seq` is expected to be defined at module level (sequence length).
    # Caution: with time_major=True the time axis of `outputs2` is axis 0, yet
    # the split/squeeze below act on axis 1; verify against the input layout.
    depth_split = tf.split(outputs2, num_or_size_splits=len_seq - 1, axis=1)
    depth_split_list = [tf.squeeze(x, axis=1) for x in depth_split]
    return depth_split_list
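# A minimal usage sketch for ConvLSTMNet (not from the original source). The
# shape below follows the split/squeeze logic above (sequence on axis 1);
# `len_seq`, the batch size, and the channel count are assumptions.
len_seq = 5
features = tf.placeholder(tf.float32, [2, len_seq - 1, 32, 32, 64])
depth_list = ConvLSTMNet(features, isTraining=True)
# depth_list: a list of (len_seq - 1) tensors, each [2, 32, 32, 1]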
def generator(self, inputs, reuse=False, scope='g_net'):
    n, h, w, c = inputs.get_shape().as_list()
    if self.args.model == 'lstm':
        with tf.variable_scope('LSTM'):
            cell = BasicConvLSTMCell([h / 4, w / 4], [3, 3], 128)
            rnn_state = cell.zero_state(batch_size=self.batch_size,
                                        dtype=tf.float32)
    x_unwrap = []
    with tf.variable_scope(scope, reuse=reuse):
        with slim.arg_scope(
                [slim.conv2d, slim.conv2d_transpose],
                activation_fn=tf.nn.relu,
                padding='SAME',
                normalizer_fn=None,
                weights_initializer=tf.contrib.layers.xavier_initializer(uniform=True),
                biases_initializer=tf.constant_initializer(0.0)):
            inp_pred = inputs
            for i in xrange(self.n_levels):
                scale = self.scale ** (self.n_levels - i - 1)
                hi = int(round(h * scale))
                wi = int(round(w * scale))
                inp_blur = tf.image.resize_images(inputs, [hi, wi], method=0)
                inp_pred = tf.stop_gradient(
                    tf.image.resize_images(inp_pred, [hi, wi], method=0))
                inp_all = tf.concat([inp_blur, inp_pred], axis=3, name='inp')
                if self.args.model == 'lstm':
                    rnn_state = tf.image.resize_images(
                        rnn_state, [hi // 4, wi // 4], method=0)

                # encoder
                conv1_1 = slim.conv2d(inp_all, 32, [5, 5], scope='enc1_1')
                conv1_2 = ResnetBlock(conv1_1, 32, 5, scope='enc1_2')
                conv1_3 = ResnetBlock(conv1_2, 32, 5, scope='enc1_3')
                conv1_4 = ResnetBlock(conv1_3, 32, 5, scope='enc1_4')
                conv2_1 = slim.conv2d(conv1_4, 64, [5, 5], stride=2, scope='enc2_1')
                conv2_2 = ResnetBlock(conv2_1, 64, 5, scope='enc2_2')
                conv2_3 = ResnetBlock(conv2_2, 64, 5, scope='enc2_3')
                conv2_4 = ResnetBlock(conv2_3, 64, 5, scope='enc2_4')
                conv3_1 = slim.conv2d(conv2_4, 128, [5, 5], stride=2, scope='enc3_1')
                conv3_2 = ResnetBlock(conv3_1, 128, 5, scope='enc3_2')
                conv3_3 = ResnetBlock(conv3_2, 128, 5, scope='enc3_3')
                conv3_4 = ResnetBlock(conv3_3, 128, 5, scope='enc3_4')

                if self.args.model == 'lstm':
                    deconv3_4, rnn_state = cell(conv3_4, rnn_state)
                else:
                    deconv3_4 = conv3_4

                # decoder
                deconv3_3 = ResnetBlock(deconv3_4, 128, 5, scope='dec3_3')
                deconv3_2 = ResnetBlock(deconv3_3, 128, 5, scope='dec3_2')
                deconv3_1 = ResnetBlock(deconv3_2, 128, 5, scope='dec3_1')
                deconv2_4 = slim.conv2d_transpose(deconv3_1, 64, [4, 4],
                                                  stride=2, scope='dec2_4')
                cat2 = deconv2_4 + conv2_4
                deconv2_3 = ResnetBlock(cat2, 64, 5, scope='dec2_3')
                deconv2_2 = ResnetBlock(deconv2_3, 64, 5, scope='dec2_2')
                deconv2_1 = ResnetBlock(deconv2_2, 64, 5, scope='dec2_1')
                deconv1_4 = slim.conv2d_transpose(deconv2_1, 32, [4, 4],
                                                  stride=2, scope='dec1_4')
                cat1 = deconv1_4 + conv1_4
                deconv1_3 = ResnetBlock(cat1, 32, 5, scope='dec1_3')
                deconv1_2 = ResnetBlock(deconv1_3, 32, 5, scope='dec1_2')
                deconv1_1 = ResnetBlock(deconv1_2, 32, 5, scope='dec1_1')
                inp_pred = slim.conv2d(deconv1_1, self.chns, [5, 5],
                                       activation_fn=None, scope='dec1_0')

                x_unwrap.append(inp_pred)
                if i == 0:
                    tf.get_variable_scope().reuse_variables()
    return x_unwrap
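# One plausible way to supervise the multi-scale outputs of generator() above
# (a sketch, not the repo's training code): compare each scale's prediction
# with the ground-truth sharp image resized to that scale. The function name
# and the plain L2 objective are assumptions for illustration.
def multiscale_l2_loss(x_unwrap, gt_sharp):
    loss_total = 0.0
    for pred in x_unwrap:
        _, hi, wi, _ = pred.get_shape().as_list()
        gt_i = tf.image.resize_images(gt_sharp, [hi, wi], method=0)
        loss_total += tf.reduce_mean((gt_i - pred) ** 2)
    return loss_total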
def generator(inputs, scope='g_net', n_levels=2):
    n, h, w, c = inputs.get_shape().as_list()
    x_unwrap = []
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        with slim.arg_scope(
                [slim.conv2d, slim.conv2d_transpose, slim.separable_conv2d],
                activation_fn=parametric_relu,
                padding='SAME',
                normalizer_fn=None,
                # normalizer_fn=tf.layers.batch_normalization,
                weights_initializer=tf.contrib.layers.xavier_initializer(uniform=True),
                biases_initializer=tf.constant_initializer(0.0)):
            # lstm = tf.keras.layers.ConvLSTM2D(filters=64, kernel_size=(1, 1),
            #                                   padding='same', return_sequences=True)
            # // instead of / keeps the state shape integral under Python 3
            cell = BasicConvLSTMCell([h // 8, w // 8], [1, 1], 64)
            rnn_state = cell.zero_state(batch_size=n, dtype=tf.float32)
            inp_pred = inputs
            for i in range(n_levels):
                scale = 0.5 ** (n_levels - i - 1)
                hi = int(round(h * scale))
                wi = int(round(w * scale))
                inp_blur = tf.image.resize_images(inputs, [hi, wi], method=0)
                inp_pred = tf.image.resize_images(inp_pred, [hi, wi], method=0)
                inp_pred = tf.stop_gradient(inp_pred)
                inp_all = tf.concat([inp_blur, inp_pred], axis=3, name='inp')
                rnn_state = tf.image.resize_images(rnn_state,
                                                   [hi // 8, wi // 8], method=0)

                # encoder
                # conv1_1 = slim.separable_conv2d(inp_all, 32, [5, 5], scope='enc1_1_dw')
                print(inp_all)  # debug
                conv0 = slim.conv2d(inp_all, 8, [5, 5], scope='enc0')
                net = slim.conv2d(conv0, 16, [5, 5], stride=2, scope='enc1_1')
                conv1 = ResBottleneckBlock(net, 16, 5, scope='enc1_2')
                net = res_bottleneck_dsconv(conv1, 32, 5, stride=2, scope='enc2_1')
                net = ResBottleneckBlock(net, 32, 5, scope='enc2_2')
                net = ResBottleneckBlock(net, 32, 5, scope='enc2_3')
                conv2 = ResBottleneckBlock(net, 32, 5, scope='enc2_4')
                net = res_bottleneck_dsconv(conv2, 64, 5, stride=2, scope='enc3_1')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_2')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_3')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_4')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_5')
                net = ResBottleneckBlock(net, 64, 5, scope='enc3_6')
                net, rnn_state = cell(net, rnn_state)
                # net = lstm(net)

                # decoder
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_6')
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_5')
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_4')
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_3')
                net = ResBottleneckBlock(net, 64, 5, scope='dec3_2')
                net = slim.conv2d_transpose(net, 32, [5, 5], stride=2, scope='dec3_1')
                net = net + conv2
                net = ResBottleneckBlock(net, 32, 5, scope='dec2_4')
                net = ResBottleneckBlock(net, 32, 5, scope='dec2_3')
                net = ResBottleneckBlock(net, 32, 5, scope='dec2_2')
                net = slim.conv2d_transpose(net, 16, [5, 5], stride=2, scope='dec2_1')
                net = net + conv1
                net = ResBottleneckBlock(net, 16, 5, scope='dec1_2')
                net = slim.conv2d_transpose(net, 8, [5, 5], stride=2, scope='dec1_1')
                net = net + conv0
                inp_pred = slim.conv2d(net, c, [5, 5], activation_fn=None, scope='dec0')
                x_unwrap.append(inp_pred)
    return x_unwrap
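# ResBottleneckBlock and res_bottleneck_dsconv are defined elsewhere in this
# codebase. Purely to illustrate the intended structure, a depthwise-separable
# residual block in slim might look like the sketch below; this is an
# assumption, not the actual definition.
def ResBottleneckBlock_sketch(x, dim, ksize, scope='res_bneck'):
    with tf.variable_scope(scope):
        net = slim.separable_conv2d(x, dim, [ksize, ksize],
                                    depth_multiplier=1, scope='dsconv1')
        net = slim.separable_conv2d(net, dim, [ksize, ksize],
                                    depth_multiplier=1, activation_fn=None,
                                    scope='dsconv2')
        return net + x  # identity shortcut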
def generator(self, inputs, reuse=False):
    n, h, w, c = inputs.get_shape().as_list()
    n_feat = self.n_feat
    kernel_size = self.kernel_size
    scope = self.model
    x_unwrap = []
    if self.args.model == 'lstm':
        with tf.variable_scope('LSTM'):
            cell = BasicConvLSTMCell([h / 4, w / 4], [3, 3], 128)
            rnn_state = cell.zero_state(batch_size=self.batch_size,
                                        dtype=tf.float32)
    with tf.variable_scope(scope, reuse=reuse):
        with slim.arg_scope(
                [slim.conv2d, slim.conv2d_transpose],
                activation_fn=tf.nn.relu,
                padding='SAME',
                normalizer_fn=None,
                weights_initializer=tf.contrib.layers.xavier_initializer(uniform=True),
                biases_initializer=tf.constant_initializer(0.0)):
            inp_pred = inputs

            if self.model == 'SRN':
                for i in xrange(self.n_levels):
                    scale = self.scale ** (self.n_levels - i - 1)
                    hi = int(round(h * scale))
                    wi = int(round(w * scale))
                    inp_blur = tf.image.resize_images(inputs, [hi, wi], method=0)
                    inp_pred = tf.stop_gradient(
                        tf.image.resize_images(inp_pred, [hi, wi], method=0))
                    inp_all = tf.concat([inp_blur, inp_pred], axis=3, name='inp')
                    if self.args.model == 'lstm':
                        rnn_state = tf.image.resize_images(
                            rnn_state, [hi // 4, wi // 4], method=0)

                    eb1 = InBlock(inp_all, n_feat, kernel_size,
                                  num_resb=self.num_resb, scope='InBlock')
                    eb2 = EBlock(eb1, n_feat * 2, kernel_size,
                                 num_resb=self.num_resb, scope='eb2')
                    eb3 = EBlock(eb2, n_feat * 4, kernel_size,
                                 num_resb=self.num_resb, scope='eb3')
                    if self.args.model == 'lstm':
                        deconv3_4, rnn_state = cell(eb3, rnn_state)
                    else:
                        deconv3_4 = eb3
                    # Feed the (possibly LSTM-processed) bottleneck into the
                    # decoder; the original passed eb3 here, which silently
                    # bypassed the ConvLSTM output in the 'lstm' configuration.
                    db1 = DBlock(deconv3_4, n_feat * 4, kernel_size, scope='db1')
                    cat2 = db1 + eb2
                    db2 = DBlock(cat2, n_feat * 2, kernel_size, scope='db2')
                    cat1 = db2 + eb1
                    inp_pred = OutBlock(cat1, n_feat, kernel_size)

                    x_unwrap.append(inp_pred)
                    if i == 0:
                        tf.get_variable_scope().reuse_variables()
                return x_unwrap

            elif self.model == 'raw':
                for i in xrange(self.n_levels):
                    scale = self.scale ** (self.n_levels - i - 1)
                    hi = int(round(h * scale))
                    wi = int(round(w * scale))
                    inp_blur = tf.image.resize_images(inputs, [hi, wi], method=0)
                    inp_pred = tf.stop_gradient(
                        tf.image.resize_images(inp_pred, [hi, wi], method=0))
                    inp_all = tf.concat([inp_blur, inp_pred], axis=3, name='inp')
                    if self.args.model == 'lstm':
                        rnn_state = tf.image.resize_images(
                            rnn_state, [hi // 4, wi // 4], method=0)

                    # encoder
                    conv1_1 = slim.conv2d(inp_all, 32, [5, 5], scope='enc1_1')
                    conv1_2 = ResnetBlock(conv1_1, 32, 5, scope='enc1_2')
                    conv1_3 = ResnetBlock(conv1_2, 32, 5, scope='enc1_3')
                    conv1_4 = ResnetBlock(conv1_3, 32, 5, scope='enc1_4')
                    conv2_1 = slim.conv2d(conv1_4, 64, [5, 5], stride=2, scope='enc2_1')
                    conv2_2 = ResnetBlock(conv2_1, 64, 5, scope='enc2_2')
                    conv2_3 = ResnetBlock(conv2_2, 64, 5, scope='enc2_3')
                    conv2_4 = ResnetBlock(conv2_3, 64, 5, scope='enc2_4')
                    conv3_1 = slim.conv2d(conv2_4, 128, [5, 5], stride=2, scope='enc3_1')
                    conv3_2 = ResnetBlock(conv3_1, 128, 5, scope='enc3_2')
                    conv3_3 = ResnetBlock(conv3_2, 128, 5, scope='enc3_3')
                    conv3_4 = ResnetBlock(conv3_3, 128, 5, scope='enc3_4')

                    if self.args.model == 'lstm':
                        deconv3_4, rnn_state = cell(conv3_4, rnn_state)
                    else:
                        deconv3_4 = conv3_4

                    # decoder
                    deconv3_3 = ResnetBlock(deconv3_4, 128, 5, scope='dec3_3')
                    deconv3_2 = ResnetBlock(deconv3_3, 128, 5, scope='dec3_2')
                    deconv3_1 = ResnetBlock(deconv3_2, 128, 5, scope='dec3_1')
                    deconv2_4 = slim.conv2d_transpose(deconv3_1, 64, [4, 4],
                                                      stride=2, scope='dec2_4')
                    cat2 = deconv2_4 + conv2_4
                    deconv2_3 = ResnetBlock(cat2, 64, 5, scope='dec2_3')
                    deconv2_2 = ResnetBlock(deconv2_3, 64, 5, scope='dec2_2')
                    deconv2_1 = ResnetBlock(deconv2_2, 64, 5, scope='dec2_1')
                    deconv1_4 = slim.conv2d_transpose(deconv2_1, 32, [4, 4],
                                                      stride=2, scope='dec1_4')
                    cat1 = deconv1_4 + conv1_4
                    deconv1_3 = ResnetBlock(cat1, 32, 5, scope='dec1_3')
                    deconv1_2 = ResnetBlock(deconv1_3, 32, 5, scope='dec1_2')
                    deconv1_1 = ResnetBlock(deconv1_2, 32, 5, scope='dec1_1')
                    inp_pred = slim.conv2d(deconv1_1, self.chns, [5, 5],
                                           activation_fn=None, scope='dec1_0')

                    x_unwrap.append(inp_pred)
                    if i == 0:
                        tf.get_variable_scope().reuse_variables()
                return x_unwrap

            elif self.model == 'DAVANet':
                conv1_1 = Conv(inputs, 8, ksize=[3, 3], scope='conv1_1')
                conv1_2 = resnet_block(conv1_1, 8, ksize=3, scope='conv1_2')
                # downsample
                conv2_1 = Conv(conv1_2, 16, ksize=[3, 3], stride=2, scope='conv2_1')
                conv2_2 = resnet_block(conv2_1, 16, ksize=3, scope='conv2_2')
                # downsample
                conv3_1 = Conv(conv2_2, 32, ksize=[3, 3], stride=2, scope='conv3_1')
                conv3_2 = resnet_block(conv3_1, 32, ksize=3, scope='conv3_2')
                # conv4_1 = Conv(conv3_2)  # incomplete and unused; left disabled

                # dilated context block
                dilation = [1, 2, 3, 4]
                convd_1 = resnet_block(conv3_2, 32, ksize=3, dilation=[2, 1], scope='convd_1')
                convd_2 = resnet_block(convd_1, 32, ksize=3, dilation=[3, 1], scope='convd_2')
                convd_3 = ms_dilate_block(convd_2, 32, dilation=dilation, scope='convd_3')

                # decoder
                upconv3_2 = Conv(convd_3, 32, ksize=[3, 3], scope='upconv3_4')
                upconv3_1 = resnet_block(upconv3_2, 32, ksize=3, scope='upconv3_3')
                # upsample
                upconv2_u = upconv(upconv3_1, 16, scope='upconv2_u')
                cat1 = tf.concat((upconv2_u, conv2_2), axis=3)
                upconv2_4 = Conv(cat1, 16, ksize=[3, 3], scope='upconv2_4')
                upconv2_3 = resnet_block(upconv2_4, 16, ksize=3, scope='upconv2_3')
                # upsample
                upconv1_u = upconv(upconv2_3, 8, scope='upconv1_u')
                cat0 = tf.concat((upconv1_u, conv1_2), axis=3)
                upconv1_2 = Conv(cat0, 8, ksize=[3, 3], scope='upconv1_2')
                upconv1_1 = resnet_block(upconv1_2, 8, ksize=3, scope='upconv1_1')
                inp_pred = Conv(upconv1_1, 3, ksize=[3, 3], scope='output')
                # list.append returns None, so append first and return the list
                x_unwrap.append(inp_pred + inputs)  # global residual connection
                return x_unwrap

            elif self.model == 'unet':
                # the commented-out ResnetBlock / plain-conv variants from
                # earlier experiments were replaced by the depthwise blocks below
                conv1_1 = slim.conv2d(inputs, 8, [kernel_size, kernel_size], scope='enc1_1')
                conv1_4 = InvertedResidualBlock(conv1_1, 8, expansion=2, scope='enc1_4')
                conv2_1 = DepthwiseSeparableConvBlock(conv1_4, 16, stride=2, scope='enc2_1')
                conv2_4 = InvertedResidualBlock(conv2_1, 16, expansion=2, scope='enc2_4')
                conv3_1 = DepthwiseSeparableConvBlock(conv2_4, 32, stride=2, scope='enc3_1')
                conv3_4 = InvertedResidualBlock(conv3_1, 32, expansion=4, scope='enc3_4')
                conv4_1 = DepthwiseSeparableConvBlock(conv3_4, 48, stride=2, scope='enc4_1')
                conv4_4 = InvertedResidualBlock(conv4_1, 48, expansion=4, scope='enc4_4')
                conv5_1 = DepthwiseSeparableConvBlock(conv4_4, 64, stride=2, scope='enc5_1')
                conv5_4 = InvertedResidualBlock(conv5_1, 64, expansion=4, scope='enc5_4')
                deconv5_4 = conv5_4

                # decoder
                deconv5_0 = slim.conv2d_transpose(deconv5_4, 48, [4, 4], stride=2, scope='deconv5_0')
                cat4 = deconv5_0 + conv4_4
                deconv4_3 = InvertedResidualBlock(cat4, 48, expansion=4, scope='deconv4_3')
                deconv4_0 = slim.conv2d_transpose(deconv4_3, 32, [4, 4], stride=2, scope='deconv4_0')
                cat3 = deconv4_0 + conv3_4
                deconv3_3 = InvertedResidualBlock(cat3, 32, expansion=4, scope='deconv3_3')
                deconv3_0 = slim.conv2d_transpose(deconv3_3, 16, [4, 4], stride=2, scope='deconv3_0')
                cat2 = deconv3_0 + conv2_4
                deconv2_3 = InvertedResidualBlock(cat2, 16, expansion=2, scope='deconv2_3')
                deconv2_0 = slim.conv2d_transpose(deconv2_3, 8, [4, 4], stride=2, scope='deconv2_0')
                cat1 = deconv2_0 + conv1_4
                deconv1_3 = InvertedResidualBlock(cat1, 8, expansion=2, scope='dec1_3')
                inp_pred = slim.conv2d(deconv1_3, 3, [kernel_size, kernel_size],
                                       activation_fn=slim.nn.sigmoid, scope='output')
                x_unwrap.append(inp_pred)
                return x_unwrap

            elif self.model == 'DMPHN':
                # encoder
                net = slim.conv2d(inputs, n_feat, [3, 3], activation_fn=None, scope='ec_conv1')
                net = ResidualLinkBlock(net, n_feat, ksize=3, scope='ec_rlb1')
                net = ResidualLinkBlock(net, n_feat, ksize=3, scope='ec_rlb2')
                net = slim.conv2d(net, n_feat * 2, [3, 3], stride=2,
                                  activation_fn=None, scope='ec_conv2')
                net = ResidualLinkBlock(net, n_feat * 2, ksize=3, scope='ec_rlb3')
                net = ResidualLinkBlock(net, n_feat * 2, ksize=3, scope='ec_rlb4')
                net = slim.conv2d(net, n_feat * 4, [3, 3], stride=2,
                                  activation_fn=None, scope='ec_conv3')
                net = ResidualLinkBlock(net, n_feat * 4, ksize=3, scope='ec_rlb5')
                net = ResidualLinkBlock(net, n_feat * 4, ksize=3, scope='ec_rlb6')
                # decoder
                net = ResidualLinkBlock(net, n_feat * 4, ksize=3, scope='dc_rlb1')
                net = ResidualLinkBlock(net, n_feat * 4, ksize=3, scope='dc_rlb2')
                net = slim.conv2d_transpose(net, n_feat * 2, [4, 4], stride=2,
                                            activation_fn=None, scope='dc_deconv1')
                net = ResidualLinkBlock(net, n_feat * 2, ksize=3, scope='dc_rlb3')
                net = ResidualLinkBlock(net, n_feat * 2, ksize=3, scope='dc_flb4')
                net = slim.conv2d_transpose(net, n_feat, [4, 4], stride=2,
                                            activation_fn=None, scope='dc_deconv2')
                net = ResidualLinkBlock(net, n_feat, ksize=3, scope='dc_rlb5')
                net = ResidualLinkBlock(net, n_feat, ksize=3, scope='dc_flb6')
                net = slim.conv2d(net, 3, [3, 3], activation_fn=None, scope='dc_conv1')
                x_unwrap.append(net)
                return x_unwrap

            elif self.model == 'DAVANet_light':
                eb1 = InBlock(inputs, n_feat, kernel_size, num_resb=1, scope='InBlock')
                eb2 = EBlock(eb1, n_feat * 2, kernel_size, num_resb=1, scope='eb1')
                eb3 = EBlock(eb2, n_feat * 4, kernel_size, num_resb=1, scope='eb2')
                context = ContextModule_lite(eb3, n_feat * 4)
                db1 = DBlock(context, n_feat * 4, kernel_size, num_resb=1, scope='db1')
                cat2 = db1 + eb2
                db2 = DBlock(cat2, n_feat * 2, kernel_size, num_resb=1, scope='db2')
                cat1 = db2 + eb1
                inp_pred = OutBlock(cat1, n_feat, kernel_size, num_resb=1, scope='OutBlock')
                x_unwrap.append(inp_pred + inputs)
                return x_unwrap

            elif self.model == 'DAVANet_dw':
                eb1 = InBlock_dw(inputs, n_feat, num_resb=self.num_resb,
                                 expansion=2, scope='InBlock')
                eb2 = EBlock_dw(eb1, n_feat * 2, num_resb=self.num_resb,
                                expansion=4, scope='eb1')
                eb3 = EBlock_dw(eb2, n_feat * 4, num_resb=self.num_resb,
                                expansion=4, scope='eb2')
                context = ContextModule_dwlite(eb3, n_feat * 4)
                db1 = DBlock_dw(context, n_feat * 4, num_resb=self.num_resb,
                                expansion=4, scope='db1')
                cat2 = db1 + eb2
                db2 = DBlock_dw(cat2, n_feat * 2, num_resb=self.num_resb,
                                expansion=4, scope='db2')
                cat1 = db2 + eb1
                inp_pred = OutBlock_dw(cat1, n_feat, num_resb=self.num_resb,
                                       expansion=2, scope='OutBlock')
                x_unwrap.append(inp_pred + inputs)
                return x_unwrap

            elif self.model == 'DFANet':
                # unfinished branch: encoder stub only, no decoder or output yet
                conv1 = slim.conv2d(inputs, 8, kernel_size=[3, 3], stride=2, scope='conv1')
            elif self.model == 'Deblur_lite':
                # unfinished branch: encoder stub only, no decoder or output yet
                conv1 = slim.conv2d(inputs, 8, [3, 3], scope='conv1')
    return x_unwrap
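# Hypothetical driver for the multi-branch generator above (names and sizes
# are assumptions): only the 'SRN' and 'raw' branches iterate over scales and
# return one prediction per level; the single-pass branches return a
# one-element list.
blur = tf.placeholder(tf.float32, [1, 256, 256, 3], name='blur')
preds = model.generator(blur, reuse=False)  # `model` wraps args/hyper-params
sharp_pred = preds[-1]                      # finest-scale prediction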
def generator(self, inputs, inputs_render, coeff, reuse=False, scope='g_net'):
    n, h, w, c = inputs.get_shape().as_list()
    if self.args.model == 'lstm':
        with tf.variable_scope('LSTM'):
            cell = BasicConvLSTMCell([h / 4, w / 4], [3, 3], 128)
            rnn_state = cell.zero_state(batch_size=self.batch_size,
                                        dtype=tf.float32)

    # pre-handle coeff: zero-pad the [n, 9, 9] coefficient grid to the feature
    # map size and add a channel axis
    def pad_coeff(coeff_p, h_h, w_w):
        h_r = int(round((h_h - 9) * 0.5))
        w_r = int(round((w_w - 9) * 0.5))
        coeff_pm = tf.pad(coeff_p,
                          [[0, 0], [h_r, h_h - 9 - h_r], [w_r, w_w - 9 - w_r]])
        coeff_pm = tf.expand_dims(coeff_pm, 3)
        coeff_pm = tf.cast(coeff_pm, tf.float32)
        return coeff_pm

    x_unwrap = []
    with tf.variable_scope(scope, reuse=reuse):
        with slim.arg_scope(
                [slim.conv2d, slim.conv2d_transpose],
                activation_fn=tf.nn.relu,
                padding='SAME',
                normalizer_fn=None,
                weights_initializer=tf.contrib.layers.xavier_initializer(uniform=True),
                biases_initializer=tf.constant_initializer(0.0)):
            inp_pred = inputs
            for i in xrange(self.n_levels):
                scale = self.scale ** (self.n_levels - i - 1)
                hi = int(round(h * scale))
                wi = int(round(w * scale))
                inp_blur = tf.image.resize_images(inputs, [hi, wi], method=0)
                inp_pred = tf.stop_gradient(
                    tf.image.resize_images(inp_pred, [hi, wi], method=0))
                inp_all = tf.concat([inp_blur, inp_pred], axis=3, name='inp')
                if self.args.model == 'lstm':
                    rnn_state = tf.image.resize_images(
                        rnn_state, [hi // 4, wi // 4], method=0)

                # encoder
                conv1_1 = slim.conv2d(inp_all, 32, [5, 5], scope='enc1_1')
                conv1_1_c_nums = 32
                # add the rendered face as extra channels
                if self.args.face == 'render' or self.args.face == 'both':
                    inp_render = tf.image.resize_images(inputs_render, [hi, wi], method=0)
                    conv1_1 = tf.concat([conv1_1, inp_render], axis=3)
                    conv1_1_c_nums = 35  # 32 features + 3 render channels
                conv1_2 = ResnetBlock(conv1_1, conv1_1_c_nums, 5, scope='enc1_2')
                conv1_3 = ResnetBlock(conv1_2, conv1_1_c_nums, 5, scope='enc1_3')
                conv1_4 = ResnetBlock(conv1_3, conv1_1_c_nums, 5, scope='enc1_4')
                conv2_1 = slim.conv2d(conv1_4, 64, [5, 5], stride=2, scope='enc2_1')
                conv2_2 = ResnetBlock(conv2_1, 64, 5, scope='enc2_2')
                conv2_3 = ResnetBlock(conv2_2, 64, 5, scope='enc2_3')
                conv2_4 = ResnetBlock(conv2_3, 64, 5, scope='enc2_4')
                conv3_1 = slim.conv2d(conv2_4, 128, [5, 5], stride=2, scope='enc3_1')
                conv3_2 = ResnetBlock(conv3_1, 128, 5, scope='enc3_2')
                conv3_3 = ResnetBlock(conv3_2, 128, 5, scope='enc3_3')
                conv3_4 = ResnetBlock(conv3_3, 128, 5, scope='enc3_4')

                if self.args.model == 'lstm':
                    deconv3_4, rnn_state = cell(conv3_4, rnn_state)
                else:
                    deconv3_4 = conv3_4

                # add the face coefficients as an extra channel
                channel_nums = 128
                if self.args.face == 'coeff' or self.args.face == 'both':
                    n_c, h_c, w_c, c_c = deconv3_4.get_shape().as_list()
                    coeff_m = pad_coeff(coeff, h_c, w_c)
                    # (an earlier fully-connected variant, kept for reference)
                    # coeff = tf.reshape(coeff, [n_c, 81])
                    # coeff = tf.cast(coeff, tf.float32)
                    # coeff_m = tf.layers.dense(coeff, units=h_c * w_c, activation=None,
                    #                           name='Fc_' + str(i), reuse=tf.AUTO_REUSE)
                    # coeff_m = tf.reshape(coeff_m, [n_c, h_c, w_c])
                    # coeff_m = tf.expand_dims(coeff_m, axis=3)
                    deconv3_4 = tf.concat([deconv3_4, coeff_m], axis=3)
                    channel_nums = 129

                # decoder
                deconv3_3 = ResnetBlock(deconv3_4, channel_nums, 5, scope='dec3_3')
                deconv3_2 = ResnetBlock(deconv3_3, channel_nums, 5, scope='dec3_2')
                deconv3_1 = ResnetBlock(deconv3_2, channel_nums, 5, scope='dec3_1')
                deconv2_4 = slim.conv2d_transpose(deconv3_1, 64, [4, 4],
                                                  stride=2, scope='dec2_4')
                cat2 = deconv2_4 + conv2_4
                deconv2_3 = ResnetBlock(cat2, 64, 5, scope='dec2_3')
                deconv2_2 = ResnetBlock(deconv2_3, 64, 5, scope='dec2_2')
                deconv2_1 = ResnetBlock(deconv2_2, 64, 5, scope='dec2_1')
                deconv1_4 = slim.conv2d_transpose(deconv2_1, conv1_1_c_nums, [4, 4],
                                                  stride=2, scope='dec1_4')
                cat1 = deconv1_4 + conv1_4
                deconv1_3 = ResnetBlock(cat1, conv1_1_c_nums, 5, scope='dec1_3')
                deconv1_2 = ResnetBlock(deconv1_3, conv1_1_c_nums, 5, scope='dec1_2')
                deconv1_1 = ResnetBlock(deconv1_2, conv1_1_c_nums, 5, scope='dec1_1')
                inp_pred = slim.conv2d(deconv1_1, self.chns, [5, 5],
                                       activation_fn=None, scope='dec1_0')

                x_unwrap.append(inp_pred)
                if i == 0:
                    tf.get_variable_scope().reuse_variables()

                # replicate the single-frame prediction so the next level's
                # input matches the channel count of the multi-frame `inputs`
                inp_pred_temp = inp_pred
                for x in xrange(1, self.n_frames):
                    inp_pred = tf.concat([inp_pred, inp_pred_temp], axis=3)
    return x_unwrap
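# pad_coeff above centers a [batch, 9, 9] coefficient grid inside a
# [batch, h, w, 1] map with zero padding. A standalone sketch of the same
# tf.pad call (the 32x32 target size is an assumption):
coeff_demo = tf.ones([1, 9, 9])
h_r, w_r = (32 - 9) // 2, (32 - 9) // 2
coeff_map = tf.pad(coeff_demo,
                   [[0, 0], [h_r, 32 - 9 - h_r], [w_r, 32 - 9 - w_r]])
coeff_map = tf.expand_dims(coeff_map, 3)  # -> [1, 32, 32, 1]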