예제 #1
0
def train_model(dataset_dir, weights_path=None):
    """
    Train the attentive derain GAN end to end.

    :param dataset_dir: dataset root consumed by
        ``data_feed_pipline.DerainDataFeeder`` ('train' and 'val' splits)
    :param weights_path: optional checkpoint path; when given, training
        resumes from it instead of initializing from scratch
    :return: None
    """
    # Build the data feed pipeline on the CPU so the GPU only runs the model.
    with tf.device('/cpu:0'):

        train_dataset = data_feed_pipline.DerainDataFeeder(
            dataset_dir=dataset_dir, flags='train')
        val_dataset = data_feed_pipline.DerainDataFeeder(
            dataset_dir=dataset_dir, flags='val')

        train_input_tensor, train_label_tensor, train_mask_tensor = train_dataset.inputs(
            CFG.TRAIN.BATCH_SIZE, 1)
        val_input_tensor, val_label_tensor, val_mask_tensor = val_dataset.inputs(
            CFG.TRAIN.BATCH_SIZE, 1)

    with tf.device('/gpu:1'):

        # define network
        derain_net = derain_drop_net.DeRainNet(
            phase=tf.constant('train', dtype=tf.string))

        # Calculate train loss and validation loss; the validation tower
        # reuses the variables created by the training tower (reuse=True).
        train_gan_loss, train_discriminative_loss, train_net_output = derain_net.compute_loss(
            input_tensor=train_input_tensor,
            gt_label_tensor=train_label_tensor,
            mask_label_tensor=train_mask_tensor,
            name='derain_net',
            reuse=False)

        val_gan_loss, val_discriminative_loss, val_net_output = derain_net.compute_loss(
            input_tensor=val_input_tensor,
            gt_label_tensor=val_label_tensor,
            mask_label_tensor=val_mask_tensor,
            name='derain_net',
            reuse=True)

        # Map the network's [-1, 1] outputs/labels back to uint8 images
        # before computing image-quality metrics (SSIM / PSNR).
        train_label_tensor_scale = tf.image.convert_image_dtype(
            image=(train_label_tensor + 1.0) / 2.0, dtype=tf.uint8)
        train_net_output_tensor_scale = tf.image.convert_image_dtype(
            image=(train_net_output + 1.0) / 2.0, dtype=tf.uint8)
        val_label_tensor_scale = tf.image.convert_image_dtype(
            image=(val_label_tensor + 1.0) / 2.0, dtype=tf.uint8)
        val_net_output_tensor_scale = tf.image.convert_image_dtype(
            image=(val_net_output + 1.0) / 2.0, dtype=tf.uint8)

        # Reverse the channel order (BGR -> RGB, assuming cv2-style input —
        # TODO confirm against the data feeder) and reduce to grayscale so the
        # metrics are computed on a single luminance channel.
        train_label_tensor_scale = tf.image.rgb_to_grayscale(
            images=tf.reverse(train_label_tensor_scale, axis=[-1]))
        train_net_output_tensor_scale = tf.image.rgb_to_grayscale(
            images=tf.reverse(train_net_output_tensor_scale, axis=[-1]))
        val_label_tensor_scale = tf.image.rgb_to_grayscale(
            images=tf.reverse(val_label_tensor_scale, axis=[-1]))
        val_net_output_tensor_scale = tf.image.rgb_to_grayscale(
            images=tf.reverse(val_net_output_tensor_scale, axis=[-1]))

        train_ssim = tf.reduce_mean(tf.image.ssim(
            train_label_tensor_scale,
            train_net_output_tensor_scale,
            max_val=255),
                                    name='avg_train_ssim')
        train_psnr = tf.reduce_mean(tf.image.psnr(
            train_label_tensor_scale,
            train_net_output_tensor_scale,
            max_val=255),
                                    name='avg_train_psnr')
        val_ssim = tf.reduce_mean(tf.image.ssim(val_label_tensor_scale,
                                                val_net_output_tensor_scale,
                                                max_val=255),
                                  name='avg_val_ssim')
        val_psnr = tf.reduce_mean(tf.image.psnr(val_label_tensor_scale,
                                                val_net_output_tensor_scale,
                                                max_val=255),
                                  name='avg_val_psnr')

        # Collect trainable variables and split them into the discriminator,
        # the attentive generator, and the (frozen-by-pretraining) VGG branch.
        train_vars = tf.trainable_variables()

        d_vars = [
            tmp for tmp in train_vars if 'discriminative_loss' in tmp.name
        ]
        g_vars = [
            tmp for tmp in train_vars
            if 'attentive_' in tmp.name and 'vgg_feats' not in tmp.name
        ]
        vgg_vars = [tmp for tmp in train_vars if "vgg_feats" in tmp.name]

        # Set optimizer with a step-decayed learning rate.
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                                   global_step,
                                                   100000,
                                                   0.1,
                                                   staircase=True)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            d_optim = tf.train.AdamOptimizer(learning_rate).minimize(
                train_discriminative_loss, var_list=d_vars)
            # FIX: pass global_step so it is incremented each step; the
            # original never incremented it, so the decayed learning rate was
            # stuck at its initial value forever.
            g_optim = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=tf.constant(0.9, tf.float32)).minimize(
                    train_gan_loss,
                    var_list=g_vars,
                    global_step=global_step)

        # Set tf saver
        saver = tf.train.Saver()
        model_save_dir = 'model/derain_gan'
        if not ops.exists(model_save_dir):
            os.makedirs(model_save_dir)
        train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                         time.localtime(time.time()))
        model_name = 'derain_gan_{:s}.ckpt'.format(str(train_start_time))
        model_save_path = ops.join(model_save_dir, model_name)

        # Set tf summary
        tboard_save_path = 'tboard/derain_gan'
        if not ops.exists(tboard_save_path):
            os.makedirs(tboard_save_path)

        train_g_loss_scalar = tf.summary.scalar(name='train_gan_loss',
                                                tensor=train_gan_loss)
        train_d_loss_scalar = tf.summary.scalar(
            name='train_discriminative_loss', tensor=train_discriminative_loss)
        train_ssim_scalar = tf.summary.scalar(name='train_image_ssim',
                                              tensor=train_ssim)
        train_psnr_scalar = tf.summary.scalar(name='train_image_psnr',
                                              tensor=train_psnr)
        val_g_loss_scalar = tf.summary.scalar(name='val_gan_loss',
                                              tensor=val_gan_loss)
        val_d_loss_scalar = tf.summary.scalar(name='val_discriminative_loss',
                                              tensor=val_discriminative_loss)
        val_ssim_scalar = tf.summary.scalar(name='val_image_ssim',
                                            tensor=val_ssim)
        val_psnr_scalar = tf.summary.scalar(name='val_image_psnr',
                                            tensor=val_psnr)

        lr_scalar = tf.summary.scalar(name='learning_rate',
                                      tensor=learning_rate)

        # FIX: the original swapped the two merge lists (the train op merged
        # the val scalars and vice versa), mislabeling every TensorBoard curve.
        train_summary_op = tf.summary.merge([
            train_g_loss_scalar, train_d_loss_scalar, train_ssim_scalar,
            train_psnr_scalar, lr_scalar
        ])
        val_summary_op = tf.summary.merge([
            val_g_loss_scalar, val_d_loss_scalar, val_ssim_scalar,
            val_psnr_scalar
        ])

        # Set sess configuration
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
        sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
        sess_config.gpu_options.allocator_type = 'BFC'

        sess = tf.Session(config=sess_config)

        summary_writer = tf.summary.FileWriter(tboard_save_path)
        summary_writer.add_graph(sess.graph)

        # Set the training parameters
        train_epochs = CFG.TRAIN.EPOCHS

        log.info('Global configuration is as follows:')
        log.info(CFG)

        with sess.as_default():

            tf.train.write_graph(
                graph_or_graph_def=sess.graph,
                logdir='',
                name='{:s}/derain_gan.pb'.format(model_save_dir))

            if weights_path is None:
                log.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                log.info(
                    'Restore model from last model checkpoint {:s}'.format(
                        weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            # Load ImageNet-pretrained VGG-16 weights for the perceptual-loss
            # branch (overwrites whatever the initializer/checkpoint set).
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()

            for vv in vgg_vars:
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception:
                    # Best effort: skip variables with no matching pretrained
                    # entry instead of aborting training.
                    continue

            # train loop
            for epoch in range(1, train_epochs + 1):
                # training part
                t_start = time.time()

                # Update both networks in one run call and fetch the metrics.
                _, _, train_d_loss, train_g_loss, train_avg_ssim, \
                train_avg_psnr, train_summary, val_summary = sess.run(
                    [d_optim, g_optim, train_discriminative_loss, train_gan_loss, train_ssim,
                     train_psnr, train_summary_op, val_summary_op]
                )

                summary_writer.add_summary(train_summary, global_step=epoch)
                summary_writer.add_summary(val_summary, global_step=epoch)

                cost_time = time.time() - t_start

                log.info('Epoch_Train: {:d} D_loss: {:.5f} G_loss: '
                         '{:.5f} SSIM: {:.5f} PSNR: {:.5f} Cost_time: {:.5f}s'.
                         format(epoch, train_d_loss, train_g_loss,
                                train_avg_ssim, train_avg_psnr, cost_time))

                # Evaluate model (cost_time logged here is the preceding
                # training step's duration, not the eval's).
                if epoch % 500 == 0:
                    val_d_loss, val_g_loss, val_avg_ssim, val_avg_psnr = sess.run(
                        [
                            val_discriminative_loss, val_gan_loss, val_ssim,
                            val_psnr
                        ])
                    log.info(
                        'Epoch_Val: {:d} D_loss: {:.5f} G_loss: '
                        '{:.5f} SSIM: {:.5f} PSNR: {:.5f} Cost_time: {:.5f}s'.
                        format(epoch, val_d_loss, val_g_loss, val_avg_ssim,
                               val_avg_psnr, cost_time))

                # Save Model
                if epoch % CFG.TRAIN.MODEL_SAVE_STEP == 0:
                    saver.save(sess=sess,
                               save_path=model_save_path,
                               global_step=epoch)

        sess.close()

    return
def test_model(image_path, weights_path, label_path=None):
    """
    Run the derain network on a single image and visualize the result.

    Restores the checkpoint, derains ``image_path``, writes 'src_img.png',
    'derain_ret.png' and four attention-map figures to the working directory,
    and — when a clean label is given — prints SSIM/PSNR against it.

    :param image_path: path to the rainy input image (read with cv2, BGR)
    :param weights_path: checkpoint to restore the network weights from
    :param label_path: optional path to the clean ground-truth image used
        for SSIM/PSNR evaluation
    :return: None
    """
    assert ops.exists(image_path)

    # Fixed-size input placeholder; the image below is resized to match.
    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[
                                      CFG.TEST.BATCH_SIZE, CFG.TEST.IMG_HEIGHT,
                                      CFG.TEST.IMG_WIDTH, 3
                                  ],
                                  name='input_tensor')

    # Load and normalize the input to [-1, 1], keeping the raw copy for
    # visualization/saving.
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (CFG.TEST.IMG_WIDTH, CFG.TEST.IMG_HEIGHT),
                       interpolation=cv2.INTER_LINEAR)
    image_vis = image
    image = np.divide(np.array(image, np.float32), 127.5) - 1.0

    label_image_vis = None
    if label_path is not None:
        label_image = cv2.imread(label_path, cv2.IMREAD_COLOR)
        label_image_vis = cv2.resize(label_image,
                                     (CFG.TEST.IMG_WIDTH, CFG.TEST.IMG_HEIGHT),
                                     interpolation=cv2.INTER_LINEAR)

    phase = tf.constant('test', tf.string)

    net = derain_drop_net.DeRainNet(phase=phase)
    output, attention_maps = net.inference(input_tensor=input_tensor,
                                           name='derain_net')

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    saver = tf.train.Saver()

    with sess.as_default():
        saver.restore(sess=sess, save_path=weights_path)

        output_image, atte_maps = sess.run(
            [output, attention_maps],
            feed_dict={input_tensor: np.expand_dims(image, 0)})

        # Rescale each channel of the single output image independently to
        # [0, 255] before casting to uint8.
        output_image = output_image[0]
        for i in range(output_image.shape[2]):
            output_image[:, :, i] = minmax_scale(output_image[:, :, i])

        output_image = np.array(output_image, np.uint8)

        if label_path is not None:
            # Metrics are computed on the luma (Y) channel of YCrCb only.
            label_image_vis_gray = cv2.cvtColor(label_image_vis,
                                                cv2.COLOR_BGR2YCR_CB)[:, :, 0]
            output_image_gray = cv2.cvtColor(output_image,
                                             cv2.COLOR_BGR2YCR_CB)[:, :, 0]
            psnr = compare_psnr(label_image_vis_gray, output_image_gray)
            ssim = compare_ssim(label_image_vis_gray, output_image_gray)

            print('SSIM: {:.5f}'.format(ssim))
            print('PSNR: {:.5f}'.format(psnr))

        # Save and visualize the results (matplotlib expects RGB, hence the
        # BGR channel reversal in imshow).
        cv2.imwrite('src_img.png', image_vis)
        cv2.imwrite('derain_ret.png', output_image)

        plt.figure('src_image')
        plt.imshow(image_vis[:, :, (2, 1, 0)])
        plt.figure('derain_ret')
        plt.imshow(output_image[:, :, (2, 1, 0)])
        plt.figure('atte_map_1')
        plt.imshow(atte_maps[0][0, :, :, 0], cmap='jet')
        plt.savefig('atte_map_1.png')
        plt.figure('atte_map_2')
        plt.imshow(atte_maps[1][0, :, :, 0], cmap='jet')
        plt.savefig('atte_map_2.png')
        plt.figure('atte_map_3')
        plt.imshow(atte_maps[2][0, :, :, 0], cmap='jet')
        plt.savefig('atte_map_3.png')
        plt.figure('atte_map_4')
        plt.imshow(atte_maps[3][0, :, :, 0], cmap='jet')
        plt.savefig('atte_map_4.png')
        plt.show()

    return
def build_saved_model(ckpt_path, export_dir):
    """
    Convert source ckpt weights file into a tensorflow SavedModel.

    :param ckpt_path: checkpoint file to restore the weights from
    :param export_dir: export directory; must not exist yet (the
        SavedModelBuilder refuses to overwrite)
    :return: None
    :raises ValueError: if ``export_dir`` already exists
    """

    if ops.exists(export_dir):
        raise ValueError('Export dir must be a dir path that does not exist')

    assert ops.exists(ops.split(ckpt_path)[0])

    # build inference tensorflow graph
    image_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[1, CFG.TRAIN.CROP_IMG_HEIGHT, CFG.TRAIN.CROP_IMG_WIDTH, 3],
                                  name='input_tensor')
    # set derain net in test phase
    phase = tf.constant('test', dtype=tf.string)
    derain_net = derain_drop_net.DeRainNet(phase=phase)

    # compute inference logits
    output, attention_maps = derain_net.inference(input_tensor=image_tensor, name='derain_net')

    # Min-max scale each channel independently to [0, 255] and emit uint8 so
    # serving clients receive a displayable image.
    output = tf.squeeze(output, 0)
    b, g, r = tf.split(output, num_or_size_splits=3, axis=-1)
    scaled_channel = []
    for channel in [b, g, r]:
        tmp = (channel - tf.reduce_min(channel)) * 255.0 / (tf.reduce_max(channel) - tf.reduce_min(channel))
        scaled_channel.append(tmp)
    output = tf.concat(values=scaled_channel, axis=-1)
    output = tf.cast(output, tf.uint8, name='derain_image_result')

    # set tensorflow saver
    saver = tf.train.Saver()

    # Set sess configuration (CPU-only: conversion needs no GPU)
    sess_config = tf.ConfigProto(device_count={"GPU": 0})
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    with sess.as_default():

        saver.restore(sess=sess, save_path=ckpt_path)

        # set model save builder
        saved_builder = sm.builder.SavedModelBuilder(export_dir)

        # add tensor need to be saved
        saved_input_tensor = sm.utils.build_tensor_info(image_tensor)
        saved_prediction_tensor = sm.utils.build_tensor_info(output)

        # build SignatureDef protobuf.
        # FIX: use the canonical PREDICT method name instead of the invalid
        # custom string 'derain_predict' so the signature validates.
        signature_def = sm.signature_def_utils.build_signature_def(
            inputs={'input_tensor': saved_input_tensor},
            outputs={'prediction': saved_prediction_tensor},
            method_name=sm.signature_constants.PREDICT_METHOD_NAME
        )

        # add graph into MetaGraphDef protobuf.
        # FIX: register under the default serving signature key; the original
        # used REGRESS_INPUTS (an input-key constant, literally 'inputs'), so
        # standard clients could not find the signature.
        saved_builder.add_meta_graph_and_variables(
            sess,
            tags=[sm.tag_constants.SERVING],
            signature_def_map={
                sm.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    signature_def
            }
        )

        # save model
        saved_builder.save()

    return
예제 #4
0
def test_model(image_path, weights_path):
    """
    Run the derain network on one image, print SSIM/PSNR and visualize.

    Writes 'src_img.png', 'derain_ret.png' and four attention-map figures to
    the current working directory.

    :param image_path: path to the rainy input image (read with cv2, BGR)
    :param weights_path: checkpoint to restore the network weights from
    :return: None
    """
    assert ops.exists(image_path)

    with tf.device('/gpu:0'):
        # Fixed-size input placeholder; the image below is resized to match.
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[
                                          CFG.TEST.BATCH_SIZE,
                                          CFG.TEST.IMG_HEIGHT,
                                          CFG.TEST.IMG_WIDTH, 3
                                      ],
                                      name='input_tensor')

    # Load and normalize the input to [-1, 1]; keep the raw copy for display.
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (CFG.TEST.IMG_WIDTH, CFG.TEST.IMG_HEIGHT))
    image_vis = image
    image = np.divide(image, 127.5) - 1

    phase = tf.constant('test', tf.string)

    with tf.device('/gpu:0'):
        net = derain_drop_net.DeRainNet(phase=phase)
        output, attention_maps = net.build(input_tensor=input_tensor,
                                           name='derain_net_loss')

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    saver = tf.train.Saver()

    with tf.device('/gpu:0'):
        with sess.as_default():
            saver.restore(sess=sess, save_path=weights_path)

            output_image, atte_maps = sess.run(
                [output, attention_maps],
                feed_dict={input_tensor: np.expand_dims(image, 0)})

            # Rescale each channel of the single output image independently
            # to [0, 255] before casting to uint8.
            output_image = output_image[0]
            for i in range(output_image.shape[2]):
                output_image[:, :, i] = minmax_scale(output_image[:, :, i])

            output_image = np.array(output_image, np.uint8)

            # Image metrics computation.
            # NOTE(review): these compare the rainy *input* (image_vis) to
            # the derained output, not a clean ground-truth label — confirm
            # this is intended rather than a missing label argument.
            image_ssim = ssim(image_vis,
                              output_image,
                              data_range=output_image.max() -
                              output_image.min(),
                              multichannel=True)
            image_psnr = psnr(image_vis,
                              output_image,
                              data_range=output_image.max() -
                              output_image.min())

            print('Image ssim: {:.5f}'.format(image_ssim))
            print('Image psnr: {:.5f}'.format(image_psnr))

            # Save and visualize the results (matplotlib expects RGB, hence
            # the BGR channel reversal in imshow).
            cv2.imwrite('src_img.png', image_vis)
            cv2.imwrite('derain_ret.png', output_image)

            plt.figure('src_image')
            plt.imshow(image_vis[:, :, (2, 1, 0)])
            plt.figure('derain_ret')
            plt.imshow(output_image[:, :, (2, 1, 0)])
            plt.figure('atte_map_1')
            plt.imshow(atte_maps[0][0, :, :, 0], cmap='jet')
            plt.savefig('atte_map_1.png')
            plt.figure('atte_map_2')
            plt.imshow(atte_maps[1][0, :, :, 0], cmap='jet')
            plt.savefig('atte_map_2.png')
            plt.figure('atte_map_3')
            plt.imshow(atte_maps[2][0, :, :, 0], cmap='jet')
            plt.savefig('atte_map_3.png')
            plt.figure('atte_map_4')
            plt.imshow(atte_maps[3][0, :, :, 0], cmap='jet')
            plt.savefig('atte_map_4.png')
            plt.show()

    return
예제 #5
0
def train_model(dataset_dir, weights_path=None):
    """
    Train the attentive derain GAN with a feed-dict based input pipeline.

    Alternates one discriminator update and one generator update per epoch
    on a single batch drawn from ``dataset_dir``.

    :param dataset_dir: dataset root containing a 'train.txt' index file
    :param weights_path: optional checkpoint to resume from; ``None`` trains
        from scratch
    :return: None
    """

    # Build the dataset
    with tf.device('/gpu:1'):
        train_dataset = data_provider.DataSet(ops.join(dataset_dir, 'train.txt'))

        # Declare the input placeholders
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3],
                                      name='input_tensor')
        label_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3],
                                      name='label_tensor')
        mask_tensor = tf.placeholder(dtype=tf.float32,
                                     shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 1],
                                     name='mask_tensor')
        lr_tensor = tf.placeholder(dtype=tf.float32,
                                   shape=[],
                                   name='learning_rate')
        phase_tensor = tf.placeholder(dtype=tf.string, shape=[], name='phase')

        # Declare the SSIM computation helper
        ssim_computer = tf_ssim.SsimComputer()

        # Declare the network
        derain_net = derain_drop_net.DeRainNet(phase=phase_tensor)

        gan_loss, discriminative_loss, net_output = derain_net.compute_loss(
            input_tensor=input_tensor,
            gt_label_tensor=label_tensor,
            mask_label_tensor=mask_tensor,
            name='derain_net_loss')

        train_vars = tf.trainable_variables()

        # SSIM between the derained output and the label, on grayscale
        ssim = ssim_computer.compute_ssim(tf.image.rgb_to_grayscale(net_output),
                                          tf.image.rgb_to_grayscale(label_tensor))

        # Split trainables into discriminator / generator / pretrained VGG
        d_vars = [tmp for tmp in train_vars if 'discriminative_loss' in tmp.name]
        g_vars = [tmp for tmp in train_vars if 'attentive_' in tmp.name and 'vgg_feats' not in tmp.name]
        vgg_vars = [tmp for tmp in train_vars if "vgg_feats" in tmp.name]

        d_optim = tf.train.AdamOptimizer(lr_tensor).minimize(discriminative_loss, var_list=d_vars)
        g_optim = tf.train.MomentumOptimizer(learning_rate=lr_tensor,
                                             momentum=tf.constant(0.9, tf.float32)).minimize(gan_loss, var_list=g_vars)

        # Set tf saver
        saver = tf.train.Saver()
        model_save_dir = 'model/derain_gan_v3'
        if not ops.exists(model_save_dir):
            os.makedirs(model_save_dir)
        train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        model_name = 'derain_gan_v3_{:s}.ckpt'.format(str(train_start_time))
        model_save_path = ops.join(model_save_dir, model_name)

        # Set tf summary
        tboard_save_path = 'tboard/derain_gan_v3'
        if not ops.exists(tboard_save_path):
            os.makedirs(tboard_save_path)
        g_loss_scalar = tf.summary.scalar(name='gan_loss', tensor=gan_loss)
        d_loss_scalar = tf.summary.scalar(name='discriminative_loss', tensor=discriminative_loss)
        ssim_scalar = tf.summary.scalar(name='image_ssim', tensor=ssim)
        lr_scalar = tf.summary.scalar(name='learning_rate', tensor=lr_tensor)
        d_summary_op = tf.summary.merge([d_loss_scalar, lr_scalar])
        g_summary_op = tf.summary.merge([g_loss_scalar, ssim_scalar])

        # Set sess configuration
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
        sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
        sess_config.gpu_options.allocator_type = 'BFC'

        sess = tf.Session(config=sess_config)

        summary_writer = tf.summary.FileWriter(tboard_save_path)
        summary_writer.add_graph(sess.graph)

        # Set the training parameters
        train_epochs = CFG.TRAIN.EPOCHS

        log.info('Global configuration is as follows:')
        log.info(CFG)

        with sess.as_default():

            tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='',
                                 name='{:s}/derain_gan.pb'.format(model_save_dir))

            if weights_path is None:
                log.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            # Load ImageNet-pretrained VGG-16 weights for the perceptual-loss
            # branch (overwrites whatever the initializer/checkpoint set)
            pretrained_weights = np.load(
                './data/vgg16.npy',
                encoding='latin1').item()

            for vv in vgg_vars:
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    # Best effort: skip variables with no matching entry
                    continue

            # train loop
            for epoch in range(train_epochs):
                # training part
                t_start = time.time()

                gt_imgs, label_imgs, mask_imgs = train_dataset.next_batch(CFG.TRAIN.BATCH_SIZE)

                # Masks come back as HxW; add the trailing channel dim
                mask_imgs = [np.expand_dims(tmp, axis=-1) for tmp in mask_imgs]

                # Update discriminative Network
                _, d_loss, d_summary = sess.run(
                    [d_optim, discriminative_loss, d_summary_op],
                    feed_dict={input_tensor: gt_imgs,
                               label_tensor: label_imgs,
                               mask_tensor: mask_imgs,
                               lr_tensor: CFG.TRAIN.LEARNING_RATE,
                               phase_tensor: 'train'})

                # Update attentive gan Network (same batch as the D update)
                _, g_loss, g_summary, ssim_val = sess.run(
                    [g_optim, gan_loss, g_summary_op, ssim],
                    feed_dict={input_tensor: gt_imgs,
                               label_tensor: label_imgs,
                               mask_tensor: mask_imgs,
                               lr_tensor: CFG.TRAIN.LEARNING_RATE,
                               phase_tensor: 'train'})

                summary_writer.add_summary(d_summary, global_step=epoch)
                summary_writer.add_summary(g_summary, global_step=epoch)

                cost_time = time.time() - t_start

                log.info('Epoch: {:d} D_loss: {:.5f} G_loss: '
                         '{:.5f} Ssim: {:.5f} Cost_time: {:.5f}s'.format(epoch, d_loss, g_loss,
                                                                         ssim_val, cost_time))
                if epoch % 2000 == 0:
                    saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
        sess.close()

    return
예제 #6
0
def test_model_TF(image_path, weights_path, save_path, label_path=None):
    """
    Derain a single image with the TF checkpoint and save the result.

    :param image_path: path to the rainy input image (read with cv2, BGR)
    :param weights_path: checkpoint to restore the network weights from
    :param save_path: output directory/prefix; the result is written to
        ``save_path + <image stem> + '_Done.png'`` (note: no separator is
        inserted, so ``save_path`` should end with '/')
    :param label_path: optional clean ground-truth image for SSIM/PSNR
    :return: (ssim, psnr) against the label, or (0, 0) when no label given
    """

    # --- 1st part: define the graph ---
    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[CFG.TEST.BATCH_SIZE, CFG.TEST.IMG_HEIGHT, CFG.TEST.IMG_WIDTH, 3],
                                  name='input_tensor'
                                  )
    phase = tf.constant('test', tf.string)

    net = derain_drop_net.DeRainNet(phase=phase)
    output, attention_maps = net.inference(input_tensor=input_tensor, name='derain_net')

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=False)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    sess = tf.Session(config=sess_config)
    saver = tf.train.Saver()

    # --- 2nd part: prepare the data ---
    assert ops.exists(image_path)

    # Load and normalize the input to [-1, 1]
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (CFG.TEST.IMG_WIDTH, CFG.TEST.IMG_HEIGHT), interpolation=cv2.INTER_LINEAR)
    image_vis = image
    image = np.divide(np.array(image, np.float32), 127.5) - 1.0

    label_image_vis = None
    if label_path is not None:
        label_image = cv2.imread(label_path, cv2.IMREAD_COLOR)
        label_image_vis = cv2.resize(
            label_image, (CFG.TEST.IMG_WIDTH, CFG.TEST.IMG_HEIGHT), interpolation=cv2.INTER_LINEAR
        )

    # --- 3rd part: run the graph ---
    ssim = 0
    psnr = 0
    with sess.as_default():
        saver.restore(sess=sess, save_path=weights_path)
        # FIX: removed the tf.reset_default_graph() call that was here —
        # per TF docs, resetting the default graph while a session built on
        # it is still active results in undefined behavior.
        output_image, atte_maps = sess.run(
            [output, attention_maps],
            feed_dict={input_tensor: np.expand_dims(image, 0)})

        # Rescale each channel of the single output image independently to
        # [0, 255] before casting to uint8.
        output_image = output_image[0]
        for i in range(output_image.shape[2]):
            output_image[:, :, i] = minmax_scale(output_image[:, :, i])

        output_image = np.array(output_image, np.uint8)

        if label_path is not None:
            # Metrics are computed on the luma (Y) channel of YCrCb only.
            label_image_vis_gray = cv2.cvtColor(label_image_vis, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
            output_image_gray = cv2.cvtColor(output_image, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
            psnr = compare_psnr(label_image_vis_gray, output_image_gray)
            ssim = compare_ssim(label_image_vis_gray, output_image_gray)

            print('SSIM: {:.5f}'.format(ssim))
            print('PSNR: {:.5f}'.format(psnr))

        # Save the result. FIX: derive the image stem with os.path instead of
        # the regex + '[:-4]' slice, which silently mangled names whose
        # extension is not exactly three characters (e.g. '.jpeg') and broke
        # on Windows path separators.
        image_name = ops.splitext(ops.basename(image_path))[0]
        cv2.imwrite(save_path + image_name + '_Done.png', output_image)

    return ssim, psnr
예제 #7
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.framework import graph_util

from attentive_gan_model import derain_drop_net

# Checkpoint the script intends to export (NOTE(review): appears unused in
# the visible portion of the script — the session below is created but the
# restore/freeze steps presumably follow past the end of this excerpt).
MODEL_WEIGHTS_FILE_PATH = './test.ckpt'

# construct compute graph: fixed single-image 240x360 BGR input, test phase
input_tensor = tf.placeholder(dtype=tf.float32,
                              shape=[1, 240, 360, 3],
                              name='input_tensor')
net = derain_drop_net.DeRainNet(phase=tf.constant('test', tf.string))
output, attention_maps = net.inference(input_tensor=input_tensor,
                                       name='derain_net')
# Drop the batch dim and give the output a stable name for graph freezing
output = tf.squeeze(output, axis=0, name='final_output')
# attention_maps = tf.squeeze(attention_maps, axis=0, name='final_attention_maps')

# create a session
saver = tf.train.Saver()

sess_config = tf.ConfigProto(allow_soft_placement=True)
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.85
sess_config.gpu_options.allow_growth = False
sess_config.gpu_options.allocator_type = 'BFC'

sess = tf.Session(config=sess_config)