def deconv2d_layer_concat(x, name, W_s, concat_x, output_shape=None, stride=2,
                          stddev=0.02, if_relu=False):
    """Transposed-convolution (upsampling) layer followed by a U-Net skip concat.

    Creates a weight variable of shape ``W_s`` and a bias of length ``W_s[2]``,
    applies ``utils.conv2d_transpose_strided``, optionally applies a ReLU, and
    then joins the result with ``concat_x`` via ``utils.crop_and_concat``.

    Args:
        x: Input tensor. The default output shape indexes it as
            (batch, height, width, channels) — assumes NHWC, TODO confirm.
        name: Suffix used to name the created variables and the ReLU op.
        W_s: Weight shape passed to ``utils.weight_variable``. ``W_s[2]`` is
            used as the bias length and therefore the output-channel count
            (presumably ``[kh, kw, out_channels, in_channels]`` as
            ``tf.nn.conv2d_transpose`` expects — confirm against utils).
        concat_x: Skip-connection tensor passed as the first argument to
            ``utils.crop_and_concat``.
        output_shape: Shape of the deconv output. When None, it defaults to
            ``[batch, height * stride, width * stride, W_s[2]]``.
        stride: Spatial stride of the transposed convolution.
        stddev: Standard deviation passed to ``utils.weight_variable``.
        if_relu: If True, apply ReLU to the deconv output before concatenating.

    Returns:
        The tensor returned by ``utils.crop_and_concat(concat_x, conv_t)``.
    """
    if output_shape is None:
        x_shape = tf.shape(x)
        # Derive the spatial size from the stride and the channel count from
        # the weight shape so the requested output agrees with W_t and b_t.
        # For stride=2 and weights that halve the channels this matches the
        # previous hard-coded (*2, //2) default exactly.
        output_shape = tf.stack(
            [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, W_s[2]])
    W_t = utils.weight_variable(W_s, stddev=stddev, name='W_' + name)
    b_t = utils.bias_variable([W_s[2]], name='b_' + name)
    conv_t = utils.conv2d_transpose_strided(x, W_t, b_t, output_shape, stride)
    if if_relu:
        conv_t = tf.nn.relu(conv_t, name=name + '_relu')
    conv_concat = utils.crop_and_concat(concat_x, conv_t)
    return conv_concat
def deconv2d_layer(x, name, W_s, output_shape=None, stride=2):
    """Apply a single transposed-convolution layer and print its shape.

    Args:
        x: Input tensor to upsample.
        name: Suffix for the created variable names and for the printed label.
        W_s: Weight shape passed to ``utils.weight_variable``. ``W_s[2]`` sets
            the bias length.
        output_shape: Desired output shape, forwarded unchanged to
            ``utils.conv2d_transpose_strided``. Handling of None is up to
            that helper.
        stride: Spatial stride forwarded to the helper.

    Returns:
        The tensor produced by ``utils.conv2d_transpose_strided``.
    """
    weight_name, bias_name = 'W_' + name, 'b_' + name
    kernel = utils.weight_variable(W_s, name=weight_name)
    bias = utils.bias_variable([W_s[2]], name=bias_name)
    upsampled = utils.conv2d_transpose_strided(
        x, kernel, bias, output_shape, stride)
    # Debug trace of the result's shape as reported by get_shape().
    print('conv_%s: ' % name, upsampled.get_shape())
    return upsampled