Example 1
class DerainNet:
    '''Derain Net: all the implemented layers are included (e.g. SEBlock,
                                                                 HEU,
                                                                 REU,
                                                                 ReHEB).
        Params:
            config: the training configuration
            sess: running session
    '''
    model_name = 'ReHEN'
    def __init__(self, config, sess=None):
        # config proto
        self.config = config
        self.channel_dim = self.config.channel_dim
        self.batch_size = self.config.batch_size
        self.patch_size = self.config.patch_size
        self.input_channels = self.config.input_channels

        # metrics
        self.ssim = SSIM(max_val=1.0)
        self.psnr = PSNR(max_val=1.0)

        # create session
        self.sess = sess

    # global average pooling
    def globalAvgPool2D(self, input_x):
        global_avgpool2d = tf.contrib.keras.layers.GlobalAvgPool2D()
        return global_avgpool2d(input_x)

    # leaky relu
    def leakyRelu(self, input_x):
        leaky_relu = tf.contrib.keras.layers.LeakyReLU(alpha=0.2)
        return leaky_relu(input_x)

    # squeeze-and-excitation block
    def SEBlock(self, input_x, input_dim=32, reduce_dim=8, scope='SEBlock'):
        with tf.variable_scope(scope) as scope:
            # global scale
            global_pl = self.globalAvgPool2D(input_x)
            reduce_fc1 = slim.fully_connected(global_pl,
                                              reduce_dim,
                                              activation_fn=tf.nn.relu)
            reduce_fc2 = slim.fully_connected(reduce_fc1,
                                              input_dim,
                                              activation_fn=None)
            g_scale = tf.nn.sigmoid(reduce_fc2)
            g_scale = tf.expand_dims(g_scale, axis=1)
            g_scale = tf.expand_dims(g_scale, axis=1)
            gs_input = input_x * g_scale
            return gs_input
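
    # Shape walk-through for SEBlock (illustrative, assuming the defaults
    # input_dim=32, reduce_dim=8 and a batch of size B):
    #   input_x              (B, H, W, 32)
    #   globalAvgPool2D  ->  (B, 32)         squeeze
    #   fc + relu        ->  (B, 8)          reduction
    #   fc (linear)      ->  (B, 32)         excitation
    #   sigmoid, expand  ->  (B, 1, 1, 32)   per-channel gates
    #   input_x * gates  ->  (B, H, W, 32)   broadcast rescaling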

    # recurrent enhancement unit
    def REU(self, input_x, h, out_dim, scope='REU'):
        with tf.variable_scope(scope):
            if h is None:
                self.conv_xz = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xz')
                self.conv_xn = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xn')
                z = tf.nn.sigmoid(self.conv_xz)
                f = tf.nn.tanh(self.conv_xn)
                h = z * f
            else:
                self.conv_hz = slim.conv2d(h, out_dim, 3, 1, scope='conv_hz')
                self.conv_hr = slim.conv2d(h, out_dim, 3, 1, scope='conv_hr')

                self.conv_xz = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xz')
                self.conv_xr = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xr')
                self.conv_xn = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xn')
                r = tf.nn.sigmoid(self.conv_hr + self.conv_xr)
                z = tf.nn.sigmoid(self.conv_hz + self.conv_xz)

                self.conv_hn = slim.conv2d(r * h,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_hn')
                n = tf.nn.tanh(self.conv_xn + self.conv_hn)
                h = (1 - z) * h + z * n

        # channel attention followed by leaky ReLU; the result serves both as
        # the block output and as the recurrent state, hence it is returned twice
        se = self.SEBlock(h, out_dim, reduce_dim=int(out_dim / 4))
        h = self.leakyRelu(se)
        return h, h
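
    # The REU implements a convolutional GRU cell with a channel-attention
    # output; written out (⊙ is element-wise multiplication, * is convolution):
    #   z  = sigmoid(W_xz * x + W_hz * h)       update gate
    #   r  = sigmoid(W_xr * x + W_hr * h)       reset gate
    #   n  = tanh(W_xn * x + W_hn * (r ⊙ h))    candidate state
    #   h' = (1 - z) ⊙ h + z ⊙ n
    # On the first call (h is None) only z and f = tanh(W_xn * x) are computed
    # and the state is initialized as h = z ⊙ f.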

    # hierarchy enhancement unit
    def HEU(self, input_x, is_training=False, scope='HEU'):
        with tf.variable_scope(scope) as scope:
            local_shortcut = input_x
            dense_shortcut = input_x

            for i in range(1, 3):
                with tf.variable_scope('ResBlock_{}'.format(i)):
                    with tf.variable_scope('Conv1'):
                        conv_tmp1 = slim.conv2d(local_shortcut,
                                                self.channel_dim, 3, 1)
                        conv_tmp1_bn = bn(conv_tmp1, is_training,
                                          UPDATE_G_OPS_COLLECTION)
                        out_tmp1 = tf.nn.relu(conv_tmp1_bn)

                    with tf.variable_scope('Conv2'):
                        conv_tmp2 = slim.conv2d(out_tmp1, self.channel_dim, 3,
                                                1)
                        conv_tmp2_bn = bn(conv_tmp2, is_training,
                                          UPDATE_G_OPS_COLLECTION)
                        out_tmp2 = tf.nn.relu(conv_tmp2_bn)
                        conv_shortcut = tf.add(local_shortcut, out_tmp2)

                dense_shortcut = tf.concat([dense_shortcut, conv_shortcut], -1)
                local_shortcut = conv_shortcut

            with tf.variable_scope('Trans'):
                conv_tmp3 = slim.conv2d(dense_shortcut, self.channel_dim, 3, 1)
                conv_tmp3_bn = bn(conv_tmp3, is_training,
                                  UPDATE_G_OPS_COLLECTION)
                conv_tmp3_se = self.SEBlock(conv_tmp3_bn,
                                            self.channel_dim,
                                            reduce_dim=int(self.channel_dim /
                                                           4))
                out_tmp3 = tf.nn.relu(conv_tmp3_se)
                heu_f = tf.add(input_x, out_tmp3)

            return heu_f

    # recurrent hierarchy enhancement block
    def ReHEB(self, input_x, h, is_training=False, scope='ReHEB'):
        with tf.variable_scope(scope):
            if input_x.get_shape().as_list()[-1] == 3:
                heu = input_x
            else:
                heu = self.HEU(input_x, is_training=is_training)
            reheb, h = self.REU(heu, h, out_dim=self.channel_dim)
        return reheb, h

    # recurrent hierarchy and enhancement network
    def derainNet(self, input_x, is_training=False, scope_name='derainNet'):
        '''ReHEN: recurrent hierarchy and enhancement network
            Params:
                input_x: input data
                is_training: training phase or testing phase
                scope_name: the scope name of the ReHEN (custom definition, default='derainNet')
            Return:
                return the derained results

            Input shape:
                4D tensor with shape '(batch_size, height, width, channels)'
                
            Output shape:
                4D tensor with shape '(batch_size, height, width, channels)'
        '''
        # reuse=tf.AUTO_REUSE lets the recurrent stages share parameters automatically
        with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:
            # convert is_training variable to tensor type
            is_training = tf.convert_to_tensor(is_training,
                                               dtype='bool',
                                               name='is_training')
            with slim.arg_scope(
                [slim.conv2d, slim.conv2d_transpose],
                    weights_initializer=tf.contrib.layers.xavier_initializer(),
                    normalizer_fn=None,
                    activation_fn=None,
                    padding='SAME'):

                stages = 4
                block_num = 5
                old_states = [None for _ in range(block_num)]
                oups = []
                ori = input_x
                shallow_f = input_x

                for stg in range(stages):
                    # recurrent hierarchy enhancement block (ReHEB)
                    with tf.variable_scope('ReHEB'):
                        states = []
                        for i in range(block_num):
                            sp = 'ReHEB_{}'.format(i)
                            shallow_f, st = self.ReHEB(shallow_f,
                                                       old_states[i],
                                                       is_training=is_training,
                                                       scope=sp)
                            states.append(st)

                    further_f = shallow_f

                    # residual map generator (RMG)
                    with tf.variable_scope('RMG'):
                        rm_conv = slim.conv2d(further_f, self.channel_dim, 3,
                                              1)
                        rm_conv_se = self.SEBlock(rm_conv,
                                                  self.channel_dim,
                                                  reduce_dim=int(
                                                      self.channel_dim / 4))
                        rm_conv_a = self.leakyRelu(rm_conv_se)
                        neg_residual_conv = slim.conv2d(
                            rm_conv_a, self.input_channels, 3, 1)
                        neg_residual = neg_residual_conv
                    shallow_f = ori - neg_residual
                    oups.append(shallow_f)
                    old_states = [tf.identity(s) for s in states]

        return oups, shallow_f, neg_residual
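
    # Unrolled, every stage predicts a negative residual map and refines the
    # estimate as shallow_f = ori - neg_residual (always subtracting from the
    # original input), so the network learns the rain layer rather than the
    # clean image directly; `oups` keeps each intermediate estimate for the
    # multi-stage MSE loss in build().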

    def build(self):
        # placeholder
        self.rain = tf.placeholder(tf.float32,
                                   [None, None, None, self.input_channels],
                                   name='rain')
        self.norain = tf.placeholder(tf.float32,
                                     [None, None, None, self.input_channels],
                                     name='norain')
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')

        # derainnet
        self.oups, self.out, self.residual = self.derainNet(
            self.rain, is_training=self.config.is_training)
        self.finer_out = tf.clip_by_value(self.out, 0, 1.0)
        self.finer_residual = tf.clip_by_value(tf.abs(self.residual), 0, 1)

        # metrics
        self.ssim_finer_tensor = tf.reduce_mean(
            self.ssim._ssim(self.norain, self.out, 0, 0))
        self.psnr_finer_tensor = tf.reduce_mean(
            self.psnr.compute_psnr(self.norain, self.out))
        self.ssim_val = tf.reduce_mean(
            self.ssim._ssim(self.norain, self.finer_out, 0, 0))
        self.psnr_val = tf.reduce_mean(
            self.psnr.compute_psnr(self.norain, self.finer_out))

        # loss function
        # MSE loss
        self.l2_loss = tf.reduce_sum([
            tf.reduce_mean(tf.square(out - self.norain)) for out in self.oups
        ])
        # SSIM loss
        self.ssim_loss = tf.log(1.0 / (self.ssim_finer_tensor + 1e-5))
        # PSNR loss
        self.psnr_loss = 1.0 / (self.psnr_finer_tensor + 1e-3)
        # total loss
        self.total_loss = self.l2_loss + 0.001 * self.ssim_loss + 0.1 * self.psnr_loss

        # optimization
        t_vars = tf.trainable_variables()
        g_vars = [var for var in t_vars if 'derainNet' in var.name]
        loss_train_ops = tf.train.AdamOptimizer(
            learning_rate=self.lr,
            beta1=self.config.beta1,
            beta2=self.config.beta2).minimize(self.total_loss, var_list=g_vars)

        # batchnorm training ops
        batchnorm_ops = tf.get_collection(UPDATE_G_OPS_COLLECTION)
        bn_update_ops = tf.group(*batchnorm_ops)
        self.train_ops = tf.group(loss_train_ops, bn_update_ops)

        # summary
        self.l2_loss_summary = tf.summary.scalar('l2_loss', self.l2_loss)
        self.total_loss_summary = tf.summary.scalar('total_loss',
                                                    self.total_loss)
        self.ssim_loss_summary = tf.summary.scalar('ssim_loss', self.ssim_loss)
        self.psnr_loss_summary = tf.summary.scalar('psnr_loss', self.psnr_loss)
        self.ssim_summary = tf.summary.scalar('ssim', self.ssim_val)
        self.psnr_summary = tf.summary.scalar('psnr', self.psnr_val)
        self.summaries = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter(self.config.logs_dir,
                                                    self.sess.graph)

        # saver
        global_variables = tf.global_variables()
        var_to_store = [
            var for var in global_variables if 'derainNet' in var.name
        ]
        self.saver = tf.train.Saver(var_list=var_to_store)

        # trainable variables
        num_params = 0
        for var in g_vars:
            tmp_num = 1
            for i in var.get_shape().as_list():
                tmp_num = tmp_num * i
            num_params = num_params + tmp_num
        print('number of trainable parameters:{}'.format(num_params))

    # training phase
    def train(self):
        # initialize variables
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            # fallback for very old TensorFlow releases
            tf.initialize_all_variables().run()

        # load training model
        check_bool = self.load_model()
        if check_bool:
            print('[!!!] loaded model successfully')
        else:
            print('[***] failed to load model')

        lr_ = self.config.lr
        start_time = time.time()
        for counter in range(self.config.iterations):
            if counter == 50000:
                lr_ = 0.1 * lr_

            # obtain training image pairs
            img, label = read_data(self.config.train_dataset,
                                   self.config.data_path, self.batch_size,
                                   self.patch_size, self.config.trainset_size)
            _, total_loss, summaries, ssim, psnr = self.sess.run(
                [
                    self.train_ops, self.total_loss, self.summaries,
                    self.ssim_val, self.psnr_val
                ],
                feed_dict={
                    self.rain: img,
                    self.norain: label,
                    self.lr: lr_
                })

            print(
                'Iteration:{}, phase:{}, loss:{:.4f}, ssim:{:.4f}, psnr:{:.4f}, lr:{}, iterations:{}'
                .format(counter, self.config.phase, total_loss, ssim, psnr,
                        lr_, self.config.iterations))

            self.summary_writer.add_summary(summaries, global_step=counter)
            if np.mod(counter, 100) == 0:
                self.sample(self.config.sample_dir, counter)

            if np.mod(counter, 500) == 0:
                self.save_model()

        # save final model
        if counter == self.config.iterations - 1:
            self.save_model()

        # training time
        end_time = time.time()
        print('training time:{} hours'.format(
            (end_time - start_time) / 3600.0))

    # sampling phase
    def sample(self, sample_dir, iterations):
        # obtaining sampling image pairs
        test_img, test_label = read_data(self.config.test_dataset,
                                         self.config.data_path,
                                         self.batch_size, self.patch_size,
                                         self.config.testset_size)
        finer_out, finer_residual = self.sess.run(
            [self.finer_out, self.finer_residual],
            feed_dict={self.rain: test_img})

        # save sampling images
        test_img_uint8 = np.uint8(test_img * 255.0)
        test_label_uint8 = np.uint8(test_label * 255.0)
        finer_out_uint8 = np.uint8(finer_out * 255.0)
        finer_residual = np.uint8(finer_residual * 255.0)
        sample = np.concatenate([
            test_img_uint8, test_label_uint8, finer_out_uint8, finer_residual
        ], 2)
        save_images(
            sample, [
                int(np.sqrt(self.batch_size)) + 1,
                int(np.sqrt(self.batch_size)) + 1
            ], '{}/{}_{}_{:04d}.jpg'.format(self.config.sample_dir,
                                            self.config.test_dataset,
                                            self.config.phase, iterations))

    # testing phase
    def test(self):
        rain = tf.placeholder(tf.float32,
                              [None, None, None, self.input_channels],
                              name='test_rain')
        norain = tf.placeholder(tf.float32,
                                [None, None, None, self.input_channels],
                                name='test_norain')

        oups, out, residual = self.derainNet(
            rain, is_training=self.config.is_training)
        finer_out = tf.clip_by_value(out, 0, 1.0)
        finer_residual = tf.clip_by_value(tf.abs(residual), 0, 1.0)

        ssim_val = tf.reduce_mean(self.ssim._ssim(norain, finer_out, 0, 0))
        psnr_val = tf.reduce_mean(self.psnr.compute_psnr(norain, finer_out))

        # load model
        self.saver = tf.train.Saver()
        check_bool = self.load_model()
        if check_bool:
            print('[!!!] loaded model successfully')
        else:
            try:
                tf.global_variables_initializer().run()
            except AttributeError:
                # fallback for very old TensorFlow releases
                tf.initialize_all_variables().run()
            print('[***] failed to load model')

        try:
            test_num, test_data_format, test_label_format = test_dic[
                self.config.test_dataset]
        except KeyError:
            print('no testing dataset named {}'.format(
                self.config.test_dataset))
            return

        ssim = []
        psnr = []
        t = 0  # accumulated inference time over all test images
        for index in range(1, test_num + 1):
            test_data_fn = test_data_format.format(index)
            test_label_fn = test_label_format.format(index)

            test_data_path = os.path.join(
                self.config.test_path.format(self.config.test_dataset),
                test_data_fn)
            test_label_path = os.path.join(
                self.config.test_path.format(self.config.test_dataset),
                test_label_fn)

            test_data_uint8 = cv2.imread(test_data_path)
            test_label_uint8 = cv2.imread(test_label_path)

            test_data_float = test_data_uint8 / 255.0
            test_label_float = test_label_uint8 / 255.0

            test_data = np.expand_dims(test_data_float, 0)
            test_label = np.expand_dims(test_label_float, 0)

            s_t = time.time()
            finer_out_val, finer_residual_val, tmp_ssim, tmp_psnr = self.sess.run(
                [finer_out, finer_residual, ssim_val, psnr_val],
                feed_dict={
                    rain: test_data,
                    norain: test_label
                })

            e_t = time.time()
            total_t = e_t - s_t
            t = t + total_t

            # save psnr and ssim metrics
            ssim.append(tmp_ssim)
            psnr.append(tmp_psnr)
            # save testing image
            test_label = np.uint8(test_label * 255)
            finer_out_val = np.uint8(finer_out_val * 255)
            finer_residual_val = np.uint8(finer_residual_val * 255)
            save_images(
                finer_out_val, [1, 1],
                '{}/{}_{}'.format(self.config.test_dir,
                                  self.config.test_dataset, test_data_fn))
            save_images(test_label, [1, 1],
                        '{}/{}'.format(self.config.test_dir, test_data_fn))
            save_images(
                finer_residual_val, [1, 1],
                '{}/residual_{}'.format(self.config.test_dir, test_data_fn))
            print('test image {}: ssim:{}, psnr:{} time:{:.4f}'.format(
                test_data_fn, tmp_ssim, tmp_psnr, total_t))

        mean_ssim = np.mean(ssim)
        mean_psnr = np.mean(psnr)
        print('Test phase: ssim:{}, psnr:{}'.format(mean_ssim, mean_psnr))
        print('Average time:{}'.format(t / test_num))

    # save model
    @property
    def model_dir(self):
        return "{}_{}_{}".format(self.model_name, self.config.train_dataset,
                                 self.batch_size)

    @property
    def model_pos(self):
        return '{}/{}/{}'.format(self.config.checkpoint_dir, self.model_dir,
                                 self.model_name)

    def save_model(self):
        # make sure checkpoint_dir/model_dir exists before saving
        os.makedirs(os.path.join(self.config.checkpoint_dir, self.model_dir),
                    exist_ok=True)
        self.saver.save(self.sess, self.model_pos)

    def load_model(self):
        if not os.path.isfile(
                os.path.join(self.config.checkpoint_dir, self.model_dir,
                             'checkpoint')):
            return False
        else:
            self.saver.restore(self.sess, self.model_pos)
            return True
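
# --- Usage sketch (not part of the original listing) ---
# A minimal driver, assuming a `config` object exposing the attributes read in
# __init__/build/train and that SSIM, PSNR, read_data and save_images come from
# the repo's utility modules; get_config() below is a hypothetical stand-in for
# however the repo builds its configuration.
if __name__ == '__main__':
    config = get_config()  # hypothetical config loader
    with tf.Session() as sess:
        net = DerainNet(config, sess=sess)
        net.build()            # construct graph, losses, summaries and saver
        if config.is_training:
            net.train()        # training loop with periodic sampling/saving
        else:
            net.test()         # restore checkpoint and evaluate the test set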
Example 2
class DerainNet:
    '''Derain Net: all the implemented layers are included (e.g. MAEB,
                                                                 convGRU,
                                                                 shared channel attention,
                                                                 channel attention).

        Params:
            config: the training configuration
            sess: running session
    '''
    model_name = 'ReMAEN'
    
    def __init__(self, config, sess=None):
        # config proto
        self.config = config
        self.channel_dim = self.config.channel_dim
        self.batch_size = self.config.batch_size
        self.patch_size = self.config.patch_size
        self.input_channels = self.config.input_channels
        
        # metrics
        self.ssim = SSIM(max_val=1.0)
        self.psnr = PSNR(max_val=1.0)

        # create session
        self.sess = sess
    
    # global average pooling
    def globalAvgPool2D(self, input_x):
        global_avgpool2d = tf.contrib.keras.layers.GlobalAvgPool2D()
        return global_avgpool2d(input_x)
    
    # leaky relu
    def leakyRelu(self, input_x):
        leaky_relu = tf.contrib.keras.layers.LeakyReLU(alpha=0.2)
        return leaky_relu(input_x)

    # squeeze-and-excitation block
    def SEBlock(self, input_x, input_dim=32, reduce_dim=8, scope='SEBlock'):
        with tf.variable_scope(scope) as scope:
            # global scale
            global_pl = self.globalAvgPool2D(input_x)
            reduce_fc1 = slim.fully_connected(global_pl, reduce_dim, activation_fn=tf.nn.relu)
            reduce_fc2 = slim.fully_connected(reduce_fc1, input_dim, activation_fn=None)
            g_scale = tf.nn.sigmoid(reduce_fc2)
            g_scale = tf.expand_dims(g_scale, axis=1)
            g_scale = tf.expand_dims(g_scale, axis=1)
            gs_input = input_x*g_scale
            return gs_input

    # convolutional GRU
    def convGRU(self, input_x, h, out_dim, scope='convGRU'):
        with tf.variable_scope(scope):
            if h is None:
                self.conv_xz = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xz')
                self.conv_xn = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xn')
                z = tf.nn.sigmoid(self.conv_xz)
                f = tf.nn.tanh(self.conv_xn)
                h = z*f
            else:
                self.conv_hz = slim.conv2d(h, out_dim, 3, 1, scope='conv_hz')
                self.conv_hr = slim.conv2d(h, out_dim, 3, 1, scope='conv_hr')

                self.conv_xz = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xz')
                self.conv_xr = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xr')
                self.conv_xn = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xn')
                r = tf.nn.sigmoid(self.conv_hr+self.conv_xr)
                z = tf.nn.sigmoid(self.conv_hz+self.conv_xz)
                
                self.conv_hn = slim.conv2d(r*h, out_dim, 3, 1, scope='conv_hn')
                n = tf.nn.tanh(self.conv_xn + self.conv_hn)
                h = (1-z)*h + z*n

        # shared channel attention block
        se = self.SEBlock(h, out_dim, reduce_dim=int(out_dim/4))
        h = self.leakyRelu(se)
        return h, h

    # multi-scale aggregation and enhancement block (MAEB)
    def MAEB(self, input_x, scope_name, dilated_factors=3):
        '''MAEB: multi-scale aggregation and enhancement block
            Params:
                input_x: input data
                scope_name: the scope name of the MAEB (custom definition)
                dilated_factors: the number of dilation rates (default=3; rates range from 1 to 3)

            Return:
                return the output of the MAEB
                
            Input shape:
                4D tensor with shape '(batch_size, height, width, channels)'
                
            Output shape:
                4D tensor with shape '(batch_size, height, width, channels)'
        '''
        dilate_c = []
        with tf.variable_scope(scope_name):
            for i in range(1, dilated_factors + 1):
                # the 'd1'/'d2' scopes repeat across dilation rates, so under
                # the enclosing tf.AUTO_REUSE the branches share kernels while
                # applying different dilation rates
                d1 = self.leakyRelu(slim.conv2d(input_x, self.channel_dim, 3, 1, rate=i, activation_fn=None, scope='d1'))
                d2 = self.leakyRelu(slim.conv2d(d1, self.channel_dim, 3, 1, rate=i, activation_fn=None, scope='d2'))
                dilate_c.append(d2)

            add = tf.add_n(dilate_c)
            shape = add.get_shape().as_list()
            output = self.SEBlock(add, shape[-1], reduce_dim=int(shape[-1]/4))
            return output
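
    # Receptive-field note (illustrative): two stacked 3x3 convolutions with
    # dilation rate i cover a (4i + 1) x (4i + 1) window, so the three branches
    # (i = 1, 2, 3) see 5x5, 9x9 and 13x13 neighborhoods before tf.add_n sums them.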

    # multi-scale aggregation and enhancement network
    def derainNet(self, input_x, scope_name='derainNet'):    
        '''ReMAEN: recurrent multi-scale aggregation and enhancement network
            Params:
                input_x: input data
                scope_name: the scope name of the ReMAEN (custom definition, default='derainNet')
            Return:
                return the derained results

            Input shape:
                4D tensor with shape '(batch_size, height, width, channels)'
                
            Output shape:
                4D tensor with shape '(batch_size, height, width, channels)'            
        '''
        # reuse=tf.AUTO_REUSE lets the recurrent stages share parameters automatically
        with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
            with slim.arg_scope([slim.conv2d, slim.conv2d_transpose],
                                weights_initializer=tf.contrib.layers.xavier_initializer(),
                                normalizer_fn=None,
                                activation_fn=None,
                                padding='SAME'):
                old_states = [None for _ in range(7)]
                stages = 3
                derain = input_x

                for i in range(stages):
                    cur_states = []
                    with tf.variable_scope('ReMAEN'):
                        with tf.variable_scope('extracting_path'):
                            MAEB1 = self.MAEB(derain, scope_name='MAEB1')
                            gru1, h1 = self.convGRU(MAEB1, old_states[0], self.channel_dim, scope='convGRU1')
                            cur_states.append(h1)

                            MAEB2 = self.MAEB(gru1, scope_name='MAEB2')
                            gru2, h2 = self.convGRU(MAEB2, old_states[1], self.channel_dim, scope='convGRU2')
                            cur_states.append(h2)
                            
                            MAEB3 = self.MAEB(gru2, scope_name='MAEB3')
                            gru3, h3 = self.convGRU(MAEB3, old_states[2], self.channel_dim, scope='convGRU3')
                            cur_states.append(h3)

                            MAEB4 = self.MAEB(gru3, scope_name='MAEB4')
                            gru4, h4 = self.convGRU(MAEB4, old_states[3], self.channel_dim, scope='convGRU4')
                            cur_states.append(h4)
                            
                        with tf.variable_scope('responding_path'):
                            up5 = slim.conv2d(gru4, self.channel_dim, 3, 1, activation_fn=tf.nn.relu, scope='conv5')
                            add5 = tf.add(up5, MAEB3)
                            gru5, h5 = self.convGRU(add5, old_states[4], self.channel_dim, scope='convGRU5')
                            cur_states.append(h5)
                            
                            up6 = slim.conv2d(gru5, self.channel_dim, 3, 1, activation_fn=tf.nn.relu, scope='conv6')
                            add6 = tf.add(up6, MAEB2)
                            gru6, h6 = self.convGRU(add6, old_states[5], self.channel_dim, scope='convGRU6')
                            cur_states.append(h6)
                            
                            up7 = slim.conv2d(gru6, self.channel_dim, 3, 1, activation_fn=tf.nn.relu, scope='conv7')
                            add7 = tf.add(up7, MAEB1)
                            gru7, h7 = self.convGRU(add7, old_states[6], self.channel_dim, scope='convGRU7')
                            cur_states.append(h7)
                        
                    # residual map generator
                    with tf.variable_scope('RMG'):
                        rmg_conv = slim.conv2d(gru7, self.channel_dim, 3, 1)
                        rmg_conv_se = self.leakyRelu(self.SEBlock(rmg_conv, self.channel_dim, reduce_dim=int(self.channel_dim/4)))
                        residual = slim.conv2d(rmg_conv_se, self.input_channels, 3, 1)
                    
                    derain = derain - residual
                    old_states = [tf.identity(s) for s in cur_states]

        return derain, residual
    
    def build(self):
        # placeholder
        self.rain = tf.placeholder(tf.float32, [None, None, None, self.input_channels], name='rain')
        self.norain = tf.placeholder(tf.float32, [None, None, None, self.input_channels], name='norain')
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
        
        # derainnet
        self.out, self.residual = self.derainNet(self.rain)
        self.finer_out = tf.clip_by_value(self.out, 0, 1.0)
        self.finer_residual = tf.clip_by_value(tf.abs(self.residual), 0, 1)
        
        # metrics
        self.ssim_finer_tensor = tf.reduce_mean(self.ssim._ssim(self.norain, self.out, 0, 0))
        self.psnr_finer_tensor = tf.reduce_mean(self.psnr.compute_psnr(self.norain, self.out))
        self.ssim_val = tf.reduce_mean(self.ssim._ssim(self.norain, self.finer_out, 0, 0))
        self.psnr_val = tf.reduce_mean(self.psnr.compute_psnr(self.norain, self.finer_out))
        
        # loss function
        # MSE loss
        self.l2_loss = tf.reduce_mean(tf.square(self.out - self.norain))
        # edge loss, kernel is imported from settings
        self.norain_edge = tf.nn.relu(tf.nn.conv2d(tf.image.rgb_to_grayscale(self.norain), kernel, [1,1,1,1],padding='SAME'))
        self.derain_edge = tf.nn.relu(tf.nn.conv2d(tf.image.rgb_to_grayscale(self.out), kernel, [1,1,1,1],padding='SAME'))
        self.edge_loss = tf.reduce_mean(tf.square(self.norain_edge-self.derain_edge))
        # total loss
        self.total_loss = self.l2_loss + 0.1*self.edge_loss
        
        # optimization
        t_vars = tf.trainable_variables()
        g_vars = [var for var in t_vars if 'derainNet' in var.name]
        self.train_ops = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=self.config.beta1, beta2=self.config.beta2).minimize(self.total_loss, var_list=g_vars)
        
        # summary
        self.l2_loss_summary = tf.summary.scalar('l2_loss', self.l2_loss)
        self.total_loss_summary = tf.summary.scalar('total_loss', self.total_loss)
        self.edge_loss_summary = tf.summary.scalar('edge_loss', self.edge_loss)
        self.ssim_summary = tf.summary.scalar('ssim', self.ssim_val)
        self.psnr_summary = tf.summary.scalar('psnr', self.psnr_val)
        self.summaries = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter(self.config.logs_dir, self.sess.graph)
        
        # saver
        global_variables = tf.global_variables()
        var_to_store = [var for var in global_variables if 'derainNet' in var.name]
        self.saver = tf.train.Saver(var_list=var_to_store)

        # trainable variables
        num_params = 0
        for var in g_vars:
            tmp_num = 1
            for i in var.get_shape().as_list():
                tmp_num = tmp_num*i
            num_params = num_params + tmp_num
        print('number of trainable parameters:{}'.format(num_params))

    # training phase
    def train(self):
        # initialize variables
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            # fallback for very old TensorFlow releases
            tf.initialize_all_variables().run()

        # load training model
        check_bool = self.load_model()
        if check_bool:
            print('[!!!] loaded model successfully')
        else:
            print('[***] failed to load model')
        
        lr_ = self.config.lr
        start_time = time.time()
        for counter in range(self.config.iterations):
            if counter == 30000:
                lr_ = 0.1*lr_

            # obtain training image pairs
            img, label = read_data(self.config.train_dataset, self.config.data_path, self.batch_size, self.patch_size, self.config.trainset_size)
            _, total_loss, summaries, ssim, psnr = self.sess.run([self.train_ops,
                                                               self.total_loss,
                                                               self.summaries,
                                                               self.ssim_val,
                                                               self.psnr_val], feed_dict={self.rain:img,
                                                                                           self.norain:label,
                                                                                           self.lr:lr_})

            print('Iteration:{}, phase:{}, loss:{:.4f}, ssim:{:.4f}, psnr:{:.4f}, lr:{}, iterations:{}'.format(counter,
                                                                                                                 self.config.phase,
                                                                                                                 total_loss,
                                                                                                                 ssim,
                                                                                                                 psnr,
                                                                                                                 lr_,
                                                                                                                 self.config.iterations))
                                
            self.summary_writer.add_summary(summaries, global_step=counter)
            if np.mod(counter, 100)==0:
                self.sample(self.config.sample_dir, counter)

            if np.mod(counter, 500)==0:
                self.save_model()
        
        # save final model
        if counter == self.config.iterations-1:
            self.save_model()

        # training time
        end_time = time.time()
        print('training time:{} hours'.format((end_time-start_time)/3600.0))

    # sampling phase
    def sample(self, sample_dir, iterations):
        # obtaining sampling image pairs
        test_img, test_label = read_data(self.config.test_dataset, self.config.data_path, self.batch_size, self.patch_size, self.config.testset_size)
        finer_out, finer_residual = self.sess.run([self.finer_out, self.finer_residual], feed_dict={self.rain:test_img})
        
        # save sampling images
        test_img_uint8 = np.uint8(test_img*255.0)
        test_label_uint8 = np.uint8(test_label*255.0)
        finer_out_uint8 = np.uint8(finer_out*255.0)
        finer_residual = np.uint8(finer_residual*255.0)
        sample = np.concatenate([test_img_uint8, test_label_uint8, finer_out_uint8, finer_residual], 2)
        save_images(sample, [int(np.sqrt(self.batch_size))+1, int(np.sqrt(self.batch_size))+1], '{}/{}_{}_{:04d}.jpg'.format(self.config.sample_dir,
                                                                                                                             self.config.test_dataset,
                                                                                                                             self.config.phase,
                                                                                                                             iterations))
    
    # testing phase
    def test(self):
        rain = tf.placeholder(tf.float32, [None, None, None, self.input_channels], name='test_rain')
        norain = tf.placeholder(tf.float32, [None, None, None, self.input_channels], name='test_norain')
        
        out, residual = self.derainNet(rain)
        finer_out = tf.clip_by_value(out, 0, 1.0)
        finer_residual = tf.clip_by_value(tf.abs(residual), 0, 1.0)

        ssim_val = tf.reduce_mean(self.ssim._ssim(norain, finer_out, 0, 0))
        psnr_val = tf.reduce_mean(self.psnr.compute_psnr(norain, finer_out))

        # load model
        self.saver = tf.train.Saver()
        check_bool = self.load_model()
        if check_bool:
            print('[!!!] loaded model successfully')
        else:
            try:
                tf.global_variables_initializer().run()
            except AttributeError:
                # fallback for very old TensorFlow releases
                tf.initialize_all_variables().run()
            print('[***] failed to load model')

        try:
            test_num, test_data_format, test_label_format = test_dic[self.config.test_dataset]
        except KeyError:
            print('no testing dataset named {}'.format(self.config.test_dataset))
            return

        ssim = []
        psnr = []
        t = 0  # accumulated inference time over all test images
        for index in range(1, test_num+1):
            test_data_fn = test_data_format.format(index)
            test_label_fn = test_label_format.format(index)
            
            test_data_path = os.path.join(self.config.test_path.format(self.config.test_dataset), test_data_fn)
            test_label_path = os.path.join(self.config.test_path.format(self.config.test_dataset), test_label_fn)

            test_data_uint8 = cv2.imread(test_data_path)
            test_label_uint8 = cv2.imread(test_label_path)

            test_data_float = test_data_uint8/255.0
            test_label_float = test_label_uint8/255.0
            
            test_data = np.expand_dims(test_data_float, 0)
            test_label = np.expand_dims(test_label_float, 0)
            
            s_t = time.time()
            finer_out_val, finer_residual_val, tmp_ssim, tmp_psnr = self.sess.run([finer_out,
                                                                                   finer_residual,
                                                                                   ssim_val,
                                                                                   psnr_val] , feed_dict={rain:test_data,
                                                                                                          norain:test_label})

            e_t = time.time()            
            total_t = e_t - s_t
            t = t + total_t

            # save psnr and ssim metrics
            ssim.append(tmp_ssim)
            psnr.append(tmp_psnr)
            # save testing image
            test_label = np.uint8(test_label*255)
            finer_out_val = np.uint8(finer_out_val*255)
            finer_residual_val = np.uint8(finer_residual_val*255)
            save_images(finer_out_val, [1,1], '{}/{}_{}'.format(self.config.test_dir, self.config.test_dataset, test_data_fn))
            save_images(test_label, [1,1], '{}/{}'.format(self.config.test_dir, test_data_fn))
            save_images(finer_residual_val, [1,1], '{}/residual_{}'.format(self.config.test_dir, test_data_fn))
            print('test image {}: ssim:{}, psnr:{} time:{:.4f}'.format(test_data_fn, tmp_ssim, tmp_psnr, total_t))
        
        mean_ssim = np.mean(ssim)
        mean_psnr = np.mean(psnr)
        print('Test phase: ssim:{}, psnr:{}'.format(mean_ssim, mean_psnr))
        print('Average time:{}'.format(t / test_num))

    # save model            
    @property
    def model_dir(self):
        return "{}_{}_{}".format(
            self.model_name, self.config.train_dataset,
            self.batch_size)
    @property
    def model_pos(self):
        return '{}/{}/{}'.format(self.config.checkpoint_dir, self.model_dir, self.model_name)

    def save_model(self):
        # make sure checkpoint_dir/model_dir exists before saving
        os.makedirs(os.path.join(self.config.checkpoint_dir, self.model_dir),
                    exist_ok=True)
        self.saver.save(self.sess, self.model_pos)
        
    def load_model(self):
        if not os.path.isfile(os.path.join(self.config.checkpoint_dir, self.model_dir,'checkpoint')):
            return False
        else:
            self.saver.restore(self.sess, self.model_pos)
            return True
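
# --- Edge-loss kernel sketch (assumption, not from the original listing) ---
# `kernel` used in build() is imported from the repo's settings module; a
# plausible stand-in for this kind of edge loss is a 3x3 Laplacian filter
# shaped (kh, kw, in_channels, out_channels) for tf.nn.conv2d on a grayscale
# image:
import numpy as np
import tensorflow as tf

laplacian = np.array([[0., 1., 0.],
                      [1., -4., 1.],
                      [0., 1., 0.]], dtype=np.float32)
kernel = tf.constant(laplacian.reshape(3, 3, 1, 1))  # hypothetical stand-in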