def __init__(self, ckpt_path="./sigma50_6000/FormResNet50.ckpt", learning_rate=1e-4):
     """Build the FormResNet denoising graph and restore trained weights.

     Creates placeholders for clean and noisy image batches and for the
     training-phase flag, runs ``FormResNet`` on the noisy input and builds
     the combined loss

         L_cross = (1 - ALPHA - BETA) * L_pix + ALPHA * L_feat + BETA * L_grad

     together with an Adam op minimising it. Finally opens a session and
     restores the graph's variables from ``ckpt_path``. The saver is created
     after the optimizer, so the checkpoint must also hold the optimizer's
     variables.

     Args:
         ckpt_path: checkpoint to restore. Defaults to the path that was
             previously hard-coded here.
         learning_rate: Adam learning rate. Defaults to the previously
             hard-coded 1e-4.
     """
     self.clean_img = tf.placeholder(tf.float32, [None, None, None, IMG_C])
     self.noised_img = tf.placeholder(tf.float32, [None, None, None, IMG_C])
     self.train_phase = tf.placeholder(tf.bool)
     form_resnet = FormResNet("FormResNet")
     self.denoised_img, self.res = form_resnet(self.noised_img, self.train_phase)
     # Pixel loss: squared error summed over axes 1-3 of each image, then
     # averaged over the batch.
     self.L_pix = tf.reduce_mean(tf.reduce_sum(tf.square(self.denoised_img - self.clean_img), [1, 2, 3]))
     # Feature loss: mean squared difference of the vgg16 outputs.
     # NOTE(review): an earlier variant tiled both images to three channels
     # (tf.concat([x, x, x], 3)) before vgg16, presumably for IMG_C == 1 —
     # confirm vgg16 accepts IMG_C-channel input.
     self.Phi = vgg16(self.denoised_img)
     self.Phi_ = vgg16(self.clean_img)
     self.L_feat = tf.reduce_mean(tf.square(self.Phi - self.Phi_))
     # Gradient loss: absolute differences of both sobel() outputs (presumably
     # the two gradient directions), summed per image and batch-averaged.
     # sobel() is built once per image and both of its outputs are reused.
     # NOTE(review): assumes sobel() creates no variables — confirm.
     grad_denoised = sobel(self.denoised_img)
     grad_clean = sobel(self.clean_img)
     self.L_grad = tf.reduce_mean(tf.reduce_sum(
         tf.abs(grad_denoised[0] - grad_clean[0]) +
         tf.abs(grad_denoised[1] - grad_clean[1]), [1, 2, 3]))
     self.L_cross = (1 - ALPHA - BETA) * self.L_pix + ALPHA * self.L_feat + BETA * self.L_grad
     self.Opt = tf.train.AdamOptimizer(learning_rate).minimize(self.L_cross)
     self.sess = tf.Session()
     saver = tf.train.Saver()
     saver.restore(self.sess, ckpt_path)
示例#2
0
    def __load_validate_input(self, image_size, img_paths):
        """Load one validation sample into ``self.img_test``.

        ``img_paths`` is unpacked, in order, into the paths of:
        - the RGB image;
        - the color theme;
        - the theme mask;
        - the local color points;
        - the points mask.

        The RGB inputs are converted with ``ops.rgb_to_lab``. L is rescaled as
        ``L / 100 * 2 - 1`` and a/b as ``(ab + 128) / 255 * 2 - 1``. The
        second output of ``ops.sobel`` is stored for L and for the two ab
        channels concatenated along axis 0.

        :param image_size: 2-D list[int], spatial size of the image, the points
            and the points mask (the theme and its mask are loaded as [1, 7])
        :param img_paths: list[str] of the five file paths listed above
        :return: None
        """
        (rgb_path, theme_path, theme_mask_path,
         points_path, points_mask_path) = img_paths
        test = self.img_test

        def unit_ab(lab):
            """Return the a/b channels of ``lab`` mapped by (x + 128) / 255 * 2 - 1."""
            return (lab[:, :, :, 1:] + 128.) / 255. * 2 - 1

        # Full image: Lab, rescaled L reshaped to a single-sample 4-D tensor,
        # rescaled ab, and the Sobel responses of both.
        test['img_rgb'] = self.__load_image(image_size, 'png', 3, 'img_rgb',
                                            rgb_path)
        test['img_lab'] = ops.rgb_to_lab(test['img_rgb'])
        l_scaled = test['img_lab'][:, :, :, 0] / 100. * 2 - 1
        test['img_l'] = tf.reshape(l_scaled, [1] + image_size + [1])
        _, test['img_l_grad'] = ops.sobel(test['img_l'])
        test['img_ab'] = unit_ab(test['img_lab'])
        a_chan = test['img_ab'][:, :, :, 0]
        b_chan = test['img_ab'][:, :, :, 1]
        _, test['img_ab_grad'] = ops.sobel(tf.concat([a_chan, b_chan], axis=0))

        # Color theme strip (1 x 7) and its single-channel mask.
        test['theme_rgb'] = self.__load_image([1, 7], 'png', 3, 'theme_rgb',
                                              theme_path)
        test['theme_lab'] = ops.rgb_to_lab(test['theme_rgb'])
        test['theme_ab'] = unit_ab(test['theme_lab'])
        test['theme_mask'] = self.__load_image([1, 7], 'png', 1, 'theme_mask',
                                               theme_mask_path)

        # Local color points and their single-channel mask.
        test['points_rgb'] = self.__load_image(image_size, 'png', 3,
                                               'points_rgb', points_path)
        test['points_lab'] = ops.rgb_to_lab(test['points_rgb'])
        test['points_ab'] = unit_ab(test['points_lab'])
        test['points_mask'] = self.__load_image(image_size, 'png', 1,
                                                'points_mask', points_mask_path)
 def __init__(self, learning_rate=1e-4):
     """Build the FormResNet denoising graph with freshly initialised weights.

     Creates placeholders for clean and noisy image batches and for the
     training-phase flag, runs ``FormResNet`` on the noisy input and builds
     the combined loss

         L_cross = (1 - ALPHA - BETA) * L_pix + ALPHA * L_feat + BETA * L_grad

     together with an Adam op minimising it. Finally opens a session and runs
     the global variable initializer.

     Args:
         learning_rate: Adam learning rate. Defaults to the previously
             hard-coded 1e-4.
     """
     self.clean_img = tf.placeholder(tf.float32, [None, None, None, IMG_C])
     self.noised_img = tf.placeholder(tf.float32, [None, None, None, IMG_C])
     self.train_phase = tf.placeholder(tf.bool)
     form_resnet = FormResNet("FormResNet")
     self.denoised_img, self.res = form_resnet(self.noised_img,
                                               self.train_phase)
     # Pixel loss: squared error summed over axes 1-3 of each image, then
     # averaged over the batch.
     self.L_pix = tf.reduce_mean(
         tf.reduce_sum(tf.square(self.denoised_img - self.clean_img),
                       [1, 2, 3]))
     # Feature loss: mean squared difference of the vgg16 outputs.
     # NOTE(review): an earlier variant tiled both images to three channels
     # (tf.concat([x, x, x], 3)) before vgg16, presumably for IMG_C == 1 —
     # confirm vgg16 accepts IMG_C-channel input.
     self.Phi = vgg16(self.denoised_img)
     self.Phi_ = vgg16(self.clean_img)
     self.L_feat = tf.reduce_mean(tf.square(self.Phi - self.Phi_))
     # Gradient loss: absolute differences of both sobel() outputs (presumably
     # the two gradient directions), summed per image and batch-averaged.
     # sobel() is built once per image and both of its outputs are reused.
     # NOTE(review): assumes sobel() creates no variables — confirm.
     grad_denoised = sobel(self.denoised_img)
     grad_clean = sobel(self.clean_img)
     self.L_grad = tf.reduce_mean(
         tf.reduce_sum(
             tf.abs(grad_denoised[0] - grad_clean[0]) +
             tf.abs(grad_denoised[1] - grad_clean[1]), [1, 2, 3]))
     self.L_cross = (
         1 - ALPHA -
         BETA) * self.L_pix + ALPHA * self.L_feat + BETA * self.L_grad
     self.Opt = tf.train.AdamOptimizer(learning_rate).minimize(self.L_cross)
     self.sess = tf.Session()
     self.sess.run(tf.global_variables_initializer())
示例#4
0
 def __build_model(self, model, model_input, is_training, scope_name):
     """Build ``model`` under ``scope_name`` and register its variables.

     Calls ``model`` with the input tensors stored in ``model_input`` and
     writes the results back into the same dict:

     * ``'output_ab'``: the model output.
     * ``'output_ab_grad'``: second output of ``ops.sobel`` applied to the
       two output channels concatenated along axis 0.
     * ``'output_rgb'``: ``ops.lab_to_rgb`` of ``'img_l'`` and
       ``'output_ab'``, mapped back from [-1, 1] to L in [0, 100] and ab in
       [-128, 127].

     The global variables whose names start with ``scope_name`` are stored in
     ``self.model_vars[scope_name]``, and a ``tf.train.Saver`` over them is
     stored in ``self.savers[scope_name]``. Their total element count is
     printed.

     :param model: function handle
     :param model_input: dict
     :param is_training: bool
     :param scope_name: str
     :return: None
     """
     print('========== Building Model ==========')
     model_input['output_ab'] = model(
         img_l_batch=model_input['img_l'],
         img_l_gra_batch=model_input['img_l_grad'],
         theme_ab_batch=model_input['theme_ab'],
         theme_mask_batch=model_input['theme_mask'],
         local_ab_batch=model_input['points_ab'],
         local_mask_batch=model_input['points_mask'],
         is_training=is_training,
         scope_name=scope_name)
     _, model_input['output_ab_grad'] = ops.sobel(
         tf.concat([
             model_input['output_ab'][:, :, :, 0],
             model_input['output_ab'][:, :, :, 1]
         ],
                   axis=0))
     # Undo the [-1, 1] normalisation: L -> [0, 100], ab -> [-128, 127].
     model_input['output_rgb'] = ops.lab_to_rgb(
         tf.concat([(model_input['img_l'] + 1.) / 2 * 100,
                    (model_input['output_ab'] + 1.) / 2 * 255 - 128],
                   axis=3))
     # NOTE(review): plain prefix match — a scope named e.g. 'gen' would also
     # collect the variables of a scope 'gen_2'. Consider matching
     # scope_name + '/' if the model creates its variables under that scope.
     self.model_vars[scope_name] = [
         var for var in tf.global_variables()
         if var.name.startswith(scope_name)
     ]
     self.savers[scope_name] = tf.train.Saver(
         var_list=self.model_vars[scope_name])
     # Count parameters from the static variable shapes in Python. The previous
     # approach added reduce_prod/reduce_sum ops to the graph and ran a
     # session just to log this number.
     paras_count = sum(var.shape.num_elements()
                       for var in self.model_vars[scope_name])
     print(
         '\033[0;35mModel \'%s\' load successfully, parameters in total: %8d\n \033[0m'
         % (scope_name, paras_count))
示例#5
0
    def __load_input_batch(self,
                           batch_size,
                           image_size,
                           is_random=True,
                           blank_rate=None):
        """Build the input pipeline and fill ``self.model_input`` with batches.

        Loads six images per sample: the full-color image, the color map, the
        color theme (1 x 7) and its mask, and the local color points and their
        mask. It optionally blanks some of the hints at random and batches the
        six tensors. The color images are then converted to Lab, with L
        rescaled as ``L / 100 * 2 - 1`` and ab as ``(ab + 128) / 255 * 2 - 1``.

        :param batch_size: int, batch size used when ``is_random`` is True.
            When ``is_random`` is False, batches always hold a single sample.
        :param image_size: 2-D list[int], spatial size of the image, color map,
            points and points mask
        :param is_random: bool, use ``tf.train.shuffle_batch`` (True) or
            ``tf.train.batch`` (False)
        :param blank_rate: 2-element list[float] or None; only used when
            ``is_random`` is True.
            - With probability ``blank_rate[0]``, the local points and their
              mask are replaced by zeros (theme only).
            - With probability ``blank_rate[1]``, the theme and its mask are
              zeroed and the color map is replaced by the full-color image
              (points only).
            - Otherwise, all hints are kept.
        :return: None
        :raises AssertionError: if the input file lists are empty or differ in
            length
        """
        list_keys = ('img_rgb', 'color_map', 'theme', 'theme_mask', 'points',
                     'points_mask')
        counts = [len(self.input_file_list[key]) for key in list_keys]
        if counts[0] == 0 or any(count != counts[0] for count in counts):
            raise AssertionError(
                'length of input list is zero or not consistent:\n'
                'img_rgb_list: %8d\n'
                'color_map_list: %8d\n'
                'theme_list: %8d\n'
                'theme_mask_list: %8d\n'
                'points_list: %8d\n'
                'points_mask_list: %8d\n' % tuple(counts))

        print('%8d samples in total.\n' % counts[0])

        img_rgb = self.__load_image(image_size, 'png', 3, 'img_rgb')
        color_map = self.__load_image(image_size, 'png', 3, 'color_map')
        theme = self.__load_image([1, 7], 'png', 3, 'theme')
        theme_mask = self.__load_image([1, 7], 'png', 1, 'theme_mask')
        points = self.__load_image(image_size, 'png', 3, 'points')
        points_mask = self.__load_image(image_size, 'png', 1, 'points_mask')

        # Substitutes for blanked hints. When the theme is blanked, the color
        # map falls back to the full-color image itself.
        color_map_blank = img_rgb
        theme_blank = tf.zeros([1, 7, 3], dtype=tf.float32)
        theme_mask_blank = tf.zeros([1, 7, 1], dtype=tf.float32)
        points_blank = tf.zeros(image_size + [3], dtype=tf.float32)
        points_mask_blank = tf.zeros(image_size + [1], dtype=tf.float32)

        def f1():  # only color theme
            return color_map, theme, theme_mask, points_blank, points_mask_blank

        def f2():  # only local points
            return color_map_blank, theme_blank, theme_mask_blank, points, points_mask

        def f3():  # color theme & local points
            return color_map, theme, theme_mask, points, points_mask

        rnd = tf.random_uniform(shape=[1],
                                minval=0,
                                maxval=1,
                                dtype=tf.float32)
        rnd = rnd[0]
        if is_random and blank_rate:
            # The three intervals [0, r0), [r0, r0 + r1) and [r0 + r1, 1) are
            # disjoint and cover [0, 1), so exactly one branch is taken.
            flag1 = tf.less(rnd, blank_rate[0])
            flag2 = tf.logical_and(tf.greater_equal(rnd, blank_rate[0]),
                                   tf.less(rnd, sum(blank_rate)))
            flag3 = tf.greater_equal(rnd, sum(blank_rate))
            color_map, theme, theme_mask, points, points_mask = \
                tf.case({flag1: f1, flag2: f2, flag3: f3}, exclusive=True)

        batch_inputs = [img_rgb, color_map, theme, theme_mask, points,
                        points_mask]
        if is_random:
            batched = tf.train.shuffle_batch(tensors=batch_inputs,
                                             batch_size=batch_size,
                                             capacity=1000,
                                             min_after_dequeue=500,
                                             num_threads=4)
            samples_per_batch = batch_size
        else:
            batched = tf.train.batch(tensors=batch_inputs,
                                     batch_size=1,
                                     capacity=500,
                                     num_threads=1)
            # The sequential pipeline always emits single-sample batches, so
            # the L-channel reshape below must use 1, not ``batch_size``
            # (which failed whenever batch_size > 1).
            samples_per_batch = 1

        model_input = self.model_input
        (model_input['img_rgb'], model_input['color_map_rgb'],
         model_input['theme_rgb'], model_input['theme_mask'],
         model_input['points_rgb'], model_input['points_mask']) = batched

        def normalize_ab(lab):
            """Return the a/b channels of ``lab`` mapped by (x + 128) / 255 * 2 - 1."""
            return (lab[:, :, :, 1:] + 128.) / 255. * 2 - 1

        # Convert to Lab color space and rescale.
        model_input['img_lab'] = ops.rgb_to_lab(model_input['img_rgb'])
        model_input['img_l'] = tf.reshape(
            model_input['img_lab'][:, :, :, 0] / 100. * 2 - 1,
            [samples_per_batch] + image_size + [1])
        _, model_input['img_l_grad'] = ops.sobel(model_input['img_l'])
        model_input['img_ab'] = normalize_ab(model_input['img_lab'])
        _, model_input['img_ab_grad'] = ops.sobel(
            tf.concat([
                model_input['img_ab'][:, :, :, 0],
                model_input['img_ab'][:, :, :, 1]
            ],
                      axis=0))

        model_input['color_map_lab'] = ops.rgb_to_lab(
            model_input['color_map_rgb'])
        model_input['color_map_ab'] = normalize_ab(model_input['color_map_lab'])

        model_input['theme_lab'] = ops.rgb_to_lab(model_input['theme_rgb'])
        model_input['theme_ab'] = normalize_ab(model_input['theme_lab'])

        model_input['points_lab'] = ops.rgb_to_lab(model_input['points_rgb'])
        model_input['points_ab'] = normalize_ab(model_input['points_lab'])