Example #1
0
def get_theme_channels(theme_img, type):
    """Decode a colour-theme image and return its normalised a/b channels.

    The encoded image is decoded as 3-channel RGB, resized to a 1x5 strip,
    scaled to [0, 1], passed through ``input_data.rgb_to_lab`` (assumed to
    return channels in L, a, b order — TODO confirm) and the last two
    channels are mapped with ``(ab + 128) / 255``.

    Args:
        theme_img: Encoded image contents as accepted by
            ``tf.image.decode_jpeg`` / ``tf.image.decode_bmp``.
        type: Image format, ``"jpg"`` or ``"bmp"``.

    Returns:
        A float32 tensor reshaped to ``[1, 1, 5, 2]``.

    Raises:
        ValueError: If ``type`` is not ``"jpg"`` or ``"bmp"``.
    """
    # Validate first: an unknown format used to fall through undecoded and
    # only fail later inside the resize with an unrelated error.
    if type not in ("jpg", "bmp"):
        raise ValueError(
            "unsupported image type %r; expected 'jpg' or 'bmp'" % (type,))
    if type == "jpg":
        theme_img = tf.image.decode_jpeg(theme_img, channels=3)
    else:
        theme_img = tf.image.decode_bmp(theme_img, channels=3)
    theme_img = tf.image.resize_images(theme_img, [1, 5])
    # rgb_to_lab is fed float64 RGB in [0, 1].
    theme_img = tf.cast(theme_img, tf.float64) / 255.0
    theme_img = input_data.rgb_to_lab(theme_img)
    theme_img = tf.cast(theme_img, tf.float32)
    theme_ab = (theme_img[:, :, 1:] + 128) / 255
    theme_ab = tf.reshape(theme_ab, [1, 1, 5, 2])
    return theme_ab
Example #2
0
def get_lab_channel(image, image_size, type):
    """Decode an image and split it into a normalised L and a raw a/b tensor.

    The encoded image is decoded as 3-channel RGB, resized to
    ``image_size[0] x image_size[1]``, scaled to [0, 1] and passed through
    ``input_data.rgb_to_lab`` (assumed to return channels in L, a, b order —
    TODO confirm).

    Args:
        image: Encoded image contents as accepted by
            ``tf.image.decode_jpeg`` / ``tf.image.decode_bmp``.
        image_size: Indexable ``(height, width)`` to resize to.
        type: Image format, ``"jpg"`` or ``"bmp"``.

    Returns:
        ``(l_channel, ab_channel)``, both float32:
        ``l_channel`` is channel 0 divided by 100, shaped ``[1, h, w, 1]``;
        ``ab_channel`` is channels 1..2 as returned by ``rgb_to_lab``,
        shaped ``[1, h, w, 2]``.

    Raises:
        ValueError: If ``type`` is not ``"jpg"`` or ``"bmp"``.
    """
    # Validate first: an unknown format used to leave the decoded tensor
    # unbound and crash with UnboundLocalError.
    if type not in ("jpg", "bmp"):
        raise ValueError(
            "unsupported image type %r; expected 'jpg' or 'bmp'" % (type,))
    decode = tf.image.decode_jpeg if type == "jpg" else tf.image.decode_bmp
    rgb = decode(image, channels=3)

    height = image_size[0]
    width = image_size[1]
    rgb = tf.image.resize_images(rgb, [height, width])
    # rgb_to_lab is fed float64 RGB in [0, 1].
    lab = input_data.rgb_to_lab(tf.cast(rgb, tf.float64) / 255.0)
    lab = tf.cast(lab, tf.float32)

    # NOTE(review): unlike the other helpers, a/b is NOT mapped with
    # (ab + 128) / 255 here; kept raw because callers may rely on it — confirm.
    ab_channel = tf.reshape(lab[:, :, 1:], [1, height, width, 2])
    l_channel = tf.reshape(lab[:, :, 0] / 100.0, [1, height, width, 1])
    return l_channel, ab_channel
Example #3
0
def run_training():
    """Train the theme-recommendation network built by ``model.built_network``.

    Loads (image, theme) batches via ``input_data``, converts both with
    ``input_data.rgb_to_lab`` and normalises them (channel 0 / 100, channels
    1..2 via ``(ab + 128) / 255``). The network receives only the image's
    channel 0 and is trained with ``model.whole_loss`` against the normalised
    theme flattened to ``[BATCH_SIZE, 1, -1]``.

    Every 100 steps summaries are written, the loss is printed and a preview
    (input image / target theme / predicted theme) is saved to ``result_dir``.
    About 20 checkpoints are saved to ``logs_dir`` over the run.

    Relies on module-level ``BATCH_SIZE``, ``CAPACITY`` and ``MAX_STEP``. If
    the loss becomes NaN a checkpoint is saved and the process exits with -1.
    """
    import sys  # only needed for the NaN abort path

    train_dir = "D:\\themeProject\\Database\\test"
    index_dir = "D:\\themeProject\\Database\\ColorTheme7"

    logs_dir = "D:\\themeProject\\logs\\themeRecommend10"
    result_dir = "themeResult/themeRecommend10/"

    # Gather input.
    image_list = input_data.get_themeRecommend_list(train_dir, index_dir)
    # train batch [BATCH_SIZE, 224, 224, 3], index batch [BATCH_SIZE, 1, 7, 3]
    train_rgb_batch, index_rgb_batch = input_data.get_themeRecommend_batch(
        image_list, BATCH_SIZE, CAPACITY)

    train_lab_batch = tf.cast(input_data.rgb_to_lab(train_rgb_batch),
                              dtype=tf.float32)
    index_lab_batch = tf.cast(input_data.rgb_to_lab(index_rgb_batch),
                              dtype=tf.float32)

    # Normalise: channel 0 / 100, channels 1..2 -> (ab + 128) / 255.
    train_l_batch = train_lab_batch[:, :, :, 0:1] / 100
    train_ab_batch = (train_lab_batch[:, :, :, 1:] + 128) / 255
    index_l_batch = index_lab_batch[:, :, :, 0:1] / 100
    index_ab_batch = (index_lab_batch[:, :, :, 1:] + 128) / 255
    train_n_batch = tf.concat([train_l_batch, train_ab_batch], 3)
    index_n_batch = tf.concat([index_l_batch, index_ab_batch], 3)

    # Flatten the target theme so it lines up with the network output.
    index_n_batch = tf.reshape(index_n_batch, [BATCH_SIZE, 1, -1])
    # out_batch [BATCH_SIZE, 1, 21]
    out_batch = model.built_network(train_l_batch)
    print(out_batch)
    sess = tf.Session()

    global_step = tf.train.get_or_create_global_step(sess.graph)
    train_loss = model.whole_loss(out_batch, index_n_batch)
    train_op = model.training(train_loss, global_step)

    # float64 tensors fetched only to build the host-side previews.
    out_lab_batch = tf.cast(tf.reshape(out_batch, [BATCH_SIZE, 1, 7, 3]),
                            tf.float64)
    index_n_batch = tf.cast(tf.reshape(index_n_batch, [BATCH_SIZE, 1, 7, 3]),
                            tf.float64)
    train_n_batch = tf.cast(train_n_batch, tf.float64)
    # merge_all() returns None when the graph defines no summaries.
    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs_dir, sess.graph)
    saver = tf.train.Saver(max_to_keep=20)
    checkpoint_path = os.path.join(logs_dir, "model.ckpt")
    # Integer interval (~20 checkpoints): with a fractional float MAX_STEP / 20
    # the modulo test skipped most of the intended checkpoints.
    checkpoint_every = max(1, MAX_STEP // 20)

    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    def denormalize_lab(lab):
        """Undo the channel-0 / 100 and (ab + 128) / 255 scaling in place."""
        lab[:, :, 0:1] = lab[:, :, 0:1] * 100
        lab[:, :, 1:] = lab[:, :, 1:] * 255 - 128
        return lab

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break
            _, tra_loss = sess.run([train_op, train_loss])
            if isnan(tra_loss):
                print('Loss is NaN.')
                saver.save(sess, checkpoint_path, global_step=step)
                sys.exit(-1)
            if step % 100 == 0:  # record the loss regularly
                if summary_op is not None:
                    merged = sess.run(summary_op)
                    train_writer.add_summary(merged, step)
                print("Step: %d,  loss: %g" % (step, tra_loss))
            if step % checkpoint_every == 0 or step == MAX_STEP - 1:
                saver.save(sess, checkpoint_path, global_step=step)

            if step % 100 == 0:
                train_lab, index_lab, out_lab = sess.run(
                    [train_n_batch, index_n_batch, out_lab_batch])
                train_lab = denormalize_lab(train_lab[0])
                index_lab = denormalize_lab(index_lab[0])
                out_lab = denormalize_lab(out_lab[0])
                print(out_lab)

                fig = plt.figure()
                for position, lab in enumerate(
                        (train_lab, index_lab, out_lab), start=1):
                    plt.subplot(1, 3, position)
                    plt.imshow(color.lab2rgb(lab))
                fig.savefig(os.path.join(result_dir, "%d_image.png" % step))
                # Close it so previews do not pile up on one reused figure.
                plt.close(fig)

    except tf.errors.OutOfRangeError:
        print("Done.")
    finally:
        coord.request_stop()

    # Wait for the queue-runner threads to finish.
    coord.join(threads=threads)
    sess.close()
Example #4
0
def run_training():
    """Train the theme-guided network built by ``model.new_built_network``.

    Loads (image, theme, theme-index, theme-mask, image-index) batches via
    ``input_data``, converts the RGB batches with ``input_data.rgb_to_lab``
    and normalises them (channel 0 / 100, channels 1..2 via
    ``(ab + 128) / 255``). The network predicts a/b from the input image's
    a/b plus the theme a/b concatenated with the theme mask. The loss comes
    from ``model.whole_loss`` and RMSE/PSNR from ``model.get_PSNR``.

    Every 100 steps summaries are written and the metrics printed. About 20
    checkpoints are saved to ``logs_dir`` over the run. Every 2000 steps an
    image grid and an a/b scatter plot are saved to ``result_dir`` and shown.

    Relies on module-level ``BATCH_SIZE``, ``CAPACITY`` and ``MAX_STEP``. If
    the loss becomes NaN a checkpoint is saved and the process exits with -1.
    """
    import sys  # only needed for the NaN abort path

    train_dir = "G:\\Database\\ColoredData\\new_colorimage1"
    theme_index_dir = "G:\\Database\\ColoredData\\ColorMap5_image4"
    image_index_dir = "G:\\Database\\ColoredData\\new_colorimage4"
    theme_dir = "G:\\Database\\ColoredData\\colorImages4_5theme"

    logs_dir = "F:\\Project_Yang\\Code\\mainProject\\logs\\log_global\\image loss 0.2"
    result_dir = "results/global/"

    # Gather input.
    image_list = input_data.get_themeInput_list(train_dir, theme_dir,
                                                theme_index_dir,
                                                image_index_dir)
    (train_batch, theme_batch, theme_index_batch, theme_mask_batch,
     image_index_batch) = input_data.get_themeObj_batch(
         image_list, BATCH_SIZE, CAPACITY)

    def to_lab(rgb_batch):
        """Run ``input_data.rgb_to_lab`` on a float64 copy; return float32."""
        return tf.cast(input_data.rgb_to_lab(tf.cast(rgb_batch, tf.float64)),
                       tf.float32)

    def normalize_ab(lab_batch):
        """Map channels 1..2 of a 4-D Lab batch with (ab + 128) / 255."""
        return (lab_batch[:, :, :, 1:] + 128) / 255

    train_lab_batch = to_lab(train_batch)
    theme_lab_batch = to_lab(theme_batch)
    index_lab_batch = to_lab(image_index_batch)
    themeIndex_lab_batch = to_lab(theme_index_batch)

    # Normalisation; any Lab-space arithmetic must happen before this point.
    # Slicing 0:1 keeps the channel axis without depending on IMAGE_SIZE.
    image_l_batch = train_lab_batch[:, :, :, 0:1] / 100
    image_ab_batch = normalize_ab(train_lab_batch)
    theme_ab_batch = normalize_ab(theme_lab_batch)
    index_ab_batch = normalize_ab(index_lab_batch)
    themeIndex_ab_batch = normalize_ab(themeIndex_lab_batch)

    # Theme input = theme a/b + mask. The mask is cast so concat sees matching
    # dtypes (a no-op when it is already float32).
    theme_input = tf.concat(
        [theme_ab_batch, tf.cast(theme_mask_batch, tf.float32)], 3)

    out_ab_batch = model.new_built_network(image_ab_batch, theme_input)

    # float64 copies fetched only for the host-side previews.
    image_l_batch = tf.cast(image_l_batch, tf.float64)
    theme_lab_batch = tf.cast(theme_lab_batch, tf.float64)

    sess = tf.Session()

    global_step = tf.train.get_or_create_global_step(sess.graph)
    train_loss, index_loss, color_loss, image_loss = model.whole_loss(
        out_ab_batch, index_ab_batch, themeIndex_ab_batch, image_ab_batch)
    train_rmse, train_psnr = model.get_PSNR(out_ab_batch, index_ab_batch)
    train_op = model.training(train_loss, global_step)

    # merge_all() returns None when the graph defines no summaries.
    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs_dir, sess.graph)
    saver = tf.train.Saver(max_to_keep=20)
    checkpoint_path = os.path.join(logs_dir, "model.ckpt")
    # Integer interval (~20 checkpoints): with a fractional float MAX_STEP / 20
    # the modulo test skipped most of the intended checkpoints.
    checkpoint_every = max(1, MAX_STEP // 20)

    # savefig does not create missing directories.
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    def save_image_grid(step, l, ab_rows, theme):
        """Save and show a 5x4 grid: per a/b array in ``ab_rows`` the L
        channel, a, b and the RGB reconstruction; the theme goes in cell 17."""
        fig = plt.figure()
        for row, ab_row in enumerate(ab_rows):
            first = 4 * row + 1
            plt.subplot(5, 4, first)
            plt.imshow(l[:, :, 0], 'gray')
            plt.subplot(5, 4, first + 1)
            plt.imshow(ab_row[:, :, 0], 'gray')
            plt.subplot(5, 4, first + 2)
            plt.imshow(ab_row[:, :, 1], 'gray')
            plt.subplot(5, 4, first + 3)
            plt.imshow(color.lab2rgb(np.concatenate([l, ab_row], 2)))
        plt.subplot(5, 4, 17)
        plt.imshow(theme)
        fig.savefig(os.path.join(result_dir, "%d_image.png" % step))
        plt.show()
        plt.close(fig)

    def scatter_ab(axes, ab, **kwargs):
        """Scatter channel 0 against channel 1 of an ``[h, w, 2]`` array."""
        return axes.scatter(ab[:, :, 0], ab[:, :, 1], alpha=0.5,
                            edgecolor='white', s=8, **kwargs)

    def save_ab_scatter(step, ab, ab_out, ab_index, colored):
        """Save and show per-image a/b scatters plus one overlay subplot."""
        fig = plt.figure(figsize=(8, 8))
        panels = (('input images', ab), ('output images', ab_out),
                  ('index images', ab_index), ('colored images', colored))
        for position, (title, ab_panel) in enumerate(panels, start=1):
            scatter_ab(plt.subplot(2, 3, position), ab_panel)
            plt.xlabel('a')
            plt.ylabel('b')
            plt.title(title)

        overlay_axes = plt.subplot(2, 3, 5)
        handles = (
            scatter_ab(overlay_axes, ab, label='image_in'),
            scatter_ab(overlay_axes, ab_index, label='image_index', c='g'),
            scatter_ab(overlay_axes, colored, label='image_colored', c='y'),
            scatter_ab(overlay_axes, ab_out, label='image_out', c='r'),
        )
        plt.xlabel('a')
        plt.ylabel('b')
        # The legend describes the overlay artists, so attach it to that axes.
        overlay_axes.legend(handles, ('input', 'index', 'colored', 'output'))
        fig.savefig(os.path.join(result_dir, "%d_scatter.png" % step))
        plt.show()
        plt.close(fig)

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break

            # Fetch every metric in the same run as the update so they all
            # describe the same batch; a second run would dequeue a new one.
            (_, tra_loss, ind_loss, col_loss, img_loss, tra_rmse,
             tra_psnr) = sess.run([train_op, train_loss, index_loss,
                                   color_loss, image_loss, train_rmse,
                                   train_psnr])

            if isnan(tra_loss):
                print('Loss is NaN.')
                saver.save(sess, checkpoint_path, global_step=step)
                sys.exit(-1)
            if step % 100 == 0:  # record the metrics regularly
                if summary_op is not None:
                    merged = sess.run(summary_op)
                    train_writer.add_summary(merged, step)
                print(
                    "Step: %d,  loss: %g,  index_loss: %g,  color_loss: %g, image_loss: %g,  rmse: %g,  psnr: %g"
                    % (step, tra_loss, ind_loss, col_loss, img_loss, tra_rmse,
                       tra_psnr))
            if step % checkpoint_every == 0 or step == MAX_STEP - 1:
                saver.save(sess, checkpoint_path, global_step=step)

            if step % 2000 == 0:
                l, ab, ab_index, ab_out, theme_lab, colored = sess.run([
                    image_l_batch, image_ab_batch, index_ab_batch,
                    out_ab_batch, theme_lab_batch, themeIndex_ab_batch
                ])
                # Undo the normalisation for display.
                l = l[0] * 100
                ab = ab[0] * 255 - 128
                ab_index = ab_index[0] * 255 - 128
                ab_out = ab_out[0] * 255 - 128
                colored = colored[0] * 255 - 128
                theme = color.lab2rgb(theme_lab[0])

                save_image_grid(step, l, (ab, ab_out, ab_index, colored),
                                theme)
                save_ab_scatter(step, ab, ab_out, ab_index, colored)

    except tf.errors.OutOfRangeError:
        print("Done.")
    finally:
        coord.request_stop()

    # Wait for the queue-runner threads to finish.
    coord.join(threads=threads)
    sess.close()
def run_training():
    """Train the sparse-hint network built by ``model.built_network``.

    Loads (image, sparse hint, 2-channel mask, index image) batches via
    ``input_data``, converts the RGB batches with ``input_data.rgb_to_lab``
    and normalises them (channel 0 / 100, channels 1..2 via
    ``(ab + 128) / 255``). The network predicts a/b from the image's a/b plus
    the sparse a/b concatenated with the first mask channel. The loss comes
    from ``model.whole_loss`` and RMSE/PSNR from ``model.get_PSNR``.

    Every 100 steps summaries are written and the metrics printed. About 20
    checkpoints are saved to ``logs_dir`` over the run. Every 1000 steps an
    image grid and an a/b scatter plot are saved to ``result_dir`` and shown.

    Relies on module-level ``BATCH_SIZE``, ``CAPACITY`` and ``MAX_STEP``. If
    the loss becomes NaN a checkpoint is saved and the process exits with -1.
    """
    import sys  # only needed for the NaN abort path

    train_dir = "F:\\Database\\ColoredData\\new_colorimage1"
    index_dir = "F:\\Database\\ColoredData\\new_colorimage4"
    mask_dir = "F:\\Database\\ColoredData\\newSparseMask"
    sparse_dir = "F:\\Database\\ColoredData\\newSparse"
    logs_dir = "E:\\Project_Yang\\Code\\logs\\local\\local3"
    result_dir = "results/local/local33/"

    # Gather input.
    image_list = input_data.get_local_list(train_dir, sparse_dir, mask_dir,
                                           index_dir)
    (train_rgb_batch, sparse_rgb_batch, mask_2channels_batch,
     index_rgb_batch) = input_data.get_local_batch(image_list, BATCH_SIZE,
                                                   CAPACITY)

    train_lab_batch = tf.cast(input_data.rgb_to_lab(train_rgb_batch),
                              tf.float32)
    sparse_lab_batch = tf.cast(input_data.rgb_to_lab(sparse_rgb_batch),
                               tf.float32)
    index_lab_batch = tf.cast(input_data.rgb_to_lab(index_rgb_batch),
                              tf.float32)
    mask_2channels_batch = tf.cast(mask_2channels_batch, tf.float32)

    # Normalisation; any Lab-space arithmetic must happen before this point.
    train_l_batch = train_lab_batch[:, :, :, 0:1] / 100
    train_ab_batch = (train_lab_batch[:, :, :, 1:] + 128) / 255
    sparse_ab_batch = (sparse_lab_batch[:, :, :, 1:] + 128) / 255
    index_ab_batch = (index_lab_batch[:, :, :, 1:] + 128) / 255
    # Hint input: sparse a/b plus the first mask channel.
    sparse_input = tf.concat(
        [sparse_ab_batch, mask_2channels_batch[:, :, :, 0:1]], 3)

    out_ab_batch = model.built_network(train_ab_batch, sparse_input)
    sess = tf.Session()

    global_step = tf.train.get_or_create_global_step(sess.graph)
    # The second value returned by whole_loss is not used here.
    train_loss, _ = model.whole_loss(out_ab_batch, index_ab_batch,
                                     train_ab_batch, mask_2channels_batch)
    train_rmse, train_psnr = model.get_PSNR(out_ab_batch, index_ab_batch)
    train_op = model.training(train_loss, global_step)

    # float64 copy fetched only for the host-side previews.
    train_l_batch = tf.cast(train_l_batch, tf.float64)

    # merge_all() returns None when the graph defines no summaries.
    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs_dir, sess.graph)
    saver = tf.train.Saver(max_to_keep=20)
    checkpoint_path = os.path.join(logs_dir, "model.ckpt")
    # Integer interval (~20 checkpoints): with a fractional float MAX_STEP / 20
    # the modulo test skipped most of the intended checkpoints.
    checkpoint_every = max(1, MAX_STEP // 20)

    # Create the output directory before the first savefig needs it.
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    def save_image_grid(step, l, ab_rows):
        """Save and show a grid with one row per a/b array in ``ab_rows``:
        the L channel, a, b and the RGB reconstruction."""
        n_rows = len(ab_rows)
        fig = plt.figure()
        for row, ab_row in enumerate(ab_rows):
            first = 4 * row + 1
            plt.subplot(n_rows, 4, first)
            plt.imshow(l[:, :, 0], 'gray')
            plt.subplot(n_rows, 4, first + 1)
            plt.imshow(ab_row[:, :, 0], 'gray')
            plt.subplot(n_rows, 4, first + 2)
            plt.imshow(ab_row[:, :, 1], 'gray')
            plt.subplot(n_rows, 4, first + 3)
            plt.imshow(color.lab2rgb(np.concatenate([l, ab_row], 2)))
        fig.savefig(os.path.join(result_dir, "%d_image.png" % step))
        plt.show()
        plt.close(fig)

    def scatter_ab(axes, ab, **kwargs):
        """Scatter channel 0 against channel 1 of an ``[h, w, 2]`` array."""
        return axes.scatter(ab[:, :, 0], ab[:, :, 1], alpha=0.5,
                            edgecolor='white', s=8, **kwargs)

    def save_ab_scatter(step, ab, ab_out, ab_index):
        """Save and show per-image a/b scatters plus one overlay subplot."""
        fig = plt.figure(figsize=(8, 8))
        panels = (('input images', ab), ('output images', ab_out),
                  ('index images', ab_index))
        for position, (title, ab_panel) in enumerate(panels, start=1):
            scatter_ab(plt.subplot(2, 2, position), ab_panel)
            plt.xlabel('a')
            plt.ylabel('b')
            plt.title(title)

        overlay_axes = plt.subplot(2, 2, 4)
        handles = (
            scatter_ab(overlay_axes, ab, label='image_in'),
            scatter_ab(overlay_axes, ab_index, label='image_index', c='g'),
            scatter_ab(overlay_axes, ab_out, label='image_out', c='r'),
        )
        plt.xlabel('a')
        plt.ylabel('b')
        overlay_axes.legend(handles, ('input', 'index', 'output'))
        fig.savefig(os.path.join(result_dir, "%d_scatter.png" % step))
        plt.show()
        plt.close(fig)

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break

            # Fetch the metrics in the same run as the update so they describe
            # the same batch; a second run would dequeue a different one.
            _, tra_loss, tra_rmse, tra_psnr = sess.run(
                [train_op, train_loss, train_rmse, train_psnr])

            if isnan(tra_loss):
                print('Loss is NaN.')
                saver.save(sess, checkpoint_path, global_step=step)
                sys.exit(-1)
            if step % 100 == 0:  # record the metrics regularly
                if summary_op is not None:
                    merged = sess.run(summary_op)
                    train_writer.add_summary(merged, step)
                print("Step: %d,    loss: %g   RMSE: %g,   PSNR: %g" %
                      (step, tra_loss, tra_rmse, tra_psnr))
            if step % checkpoint_every == 0 or step == MAX_STEP - 1:
                saver.save(sess, checkpoint_path, global_step=step)

            if step % 1000 == 0:
                l, ab, ab_index, ab_out = sess.run([
                    train_l_batch, train_ab_batch, index_ab_batch, out_ab_batch
                ])
                # Undo the normalisation for display.
                l = l[0] * 100
                ab = ab[0] * 255 - 128
                ab_out = ab_out[0] * 255 - 128
                ab_index = ab_index[0] * 255 - 128

                save_image_grid(step, l, (ab, ab_out, ab_index))
                save_ab_scatter(step, ab, ab_out, ab_index)

    except tf.errors.OutOfRangeError:
        print("Done.")
    finally:
        coord.request_stop()

    # Wait for the queue-runner threads to finish.
    coord.join(threads=threads)
    sess.close()