def get_theme_channels(theme_img, type):
    """Decode an encoded colour-theme image and return its normalised ab channels.

    The image is decoded as 3-channel RGB, resized to a 1x5 strip (presumably
    one pixel per theme colour — confirm), scaled to [0, 1], converted with
    ``input_data.rgb_to_lab`` and the a/b channels are mapped with
    ``(ab + 128) / 255`` (the same normalisation used by the training code).

    Args:
        theme_img: Encoded image bytes.
        type: Encoding of ``theme_img``: ``"jpg"`` or ``"bmp"``. (The name
            shadows the builtin but is kept for caller compatibility.)

    Returns:
        float32 tensor reshaped to ``[1, 1, 5, 2]``.

    Raises:
        ValueError: if ``type`` is not ``"jpg"`` or ``"bmp"``.
    """
    if type == "jpg":
        theme_img = tf.image.decode_jpeg(theme_img, channels=3)
    elif type == "bmp":
        theme_img = tf.image.decode_bmp(theme_img, channels=3)
    else:
        # Without this, the raw encoded bytes would flow into resize_images
        # and fail there with an obscure error.
        raise ValueError(
            "Unsupported image type: %r (expected 'jpg' or 'bmp')" % (type,))
    theme_img = tf.image.resize_images(theme_img, [1, 5])
    # rgb_to_lab is fed float64 RGB in [0, 1]; the result is used as float32.
    theme_img = tf.cast(theme_img, tf.float64) / 255.0
    theme_img = input_data.rgb_to_lab(theme_img)
    theme_img = tf.cast(theme_img, tf.float32)
    theme_ab = (theme_img[:, :, 1:] + 128) / 255
    return tf.reshape(theme_ab, [1, 1, 5, 2])
def get_lab_channel(image, image_size, type):
    """Decode an encoded image and split it into batched L and ab tensors.

    The image is decoded as 3-channel RGB, resized to ``image_size``, scaled
    to [0, 1] and converted with ``input_data.rgb_to_lab``.

    Args:
        image: Encoded image bytes.
        image_size: Sequence whose first two items are ``(height, width)``.
        type: Encoding of ``image``: ``"jpg"`` or ``"bmp"``. (The name
            shadows the builtin but is kept for caller compatibility.)

    Returns:
        ``(l_channel, ab_channel)``:
        ``l_channel`` is float32 ``[1, height, width, 1]`` holding L / 100.
        ``ab_channel`` is float32 ``[1, height, width, 2]`` holding the raw
        a/b values. Unlike ``get_theme_channels``, ab is NOT mapped with
        ``(ab + 128) / 255`` here.

    Raises:
        ValueError: if ``type`` is not ``"jpg"`` or ``"bmp"``.
    """
    if type == "jpg":
        rgb = tf.image.decode_jpeg(image, channels=3)
    elif type == "bmp":
        rgb = tf.image.decode_bmp(image, channels=3)
    else:
        # Without this, the raw encoded bytes would flow into resize_images.
        raise ValueError(
            "Unsupported image type: %r (expected 'jpg' or 'bmp')" % (type,))
    height = image_size[0]
    width = image_size[1]
    rgb = tf.image.resize_images(rgb, [height, width])
    lab = input_data.rgb_to_lab(tf.cast(rgb, tf.float64) / 255.0)
    lab = tf.cast(lab, tf.float32)
    l_channel = tf.reshape(lab[:, :, 0] / 100.0, [1, height, width, 1])
    ab_channel = tf.reshape(lab[:, :, 1:], [1, height, width, 2])
    return l_channel, ab_channel
def run_training():
    """Train the theme-recommendation network (image L channel -> 7-colour theme).

    Reads training images and their reference 7-colour themes through
    ``input_data``, converts both to Lab and normalises every channel
    (L / 100, (ab + 128) / 255). It then trains ``model.built_network`` on
    the image L channel against the flattened theme.
    - Every 100 steps it writes a summary, prints the loss and saves a
      side-by-side preview (image / reference theme / predicted theme).
    - About 20 checkpoints are saved over ``MAX_STEP`` steps.
    - If the loss becomes NaN, a checkpoint is saved and the process exits
      with status -1.

    NOTE(review): this module defines ``run_training`` more than once; only
    the last definition is bound at import time.
    """
    import sys

    train_dir = "D:\\themeProject\\Database\\test"
    index_dir = "D:\\themeProject\\Database\\ColorTheme7"
    logs_dir = "D:\\themeProject\\logs\\themeRecommend10"
    result_dir = "themeResult/themeRecommend10/"

    # Build the input pipeline.
    image_list = input_data.get_themeRecommend_list(train_dir, index_dir)
    # train batch [BATCH_SIZE, 224, 224, 3], index batch [BATCH_SIZE, 1, 7, 3]
    train_rgb_batch, index_rgb_batch = input_data.get_themeRecommend_batch(
        image_list, BATCH_SIZE, CAPACITY)
    train_lab_batch = tf.cast(input_data.rgb_to_lab(train_rgb_batch), dtype=tf.float32)
    index_lab_batch = tf.cast(input_data.rgb_to_lab(index_rgb_batch), dtype=tf.float32)

    # Normalise: L / 100 and (ab + 128) / 255.
    train_l_batch = train_lab_batch[:, :, :, 0:1] / 100
    train_ab_batch = (train_lab_batch[:, :, :, 1:] + 128) / 255
    index_l_batch = index_lab_batch[:, :, :, 0:1] / 100
    index_ab_batch = (index_lab_batch[:, :, :, 1:] + 128) / 255
    train_n_batch = tf.concat([train_l_batch, train_ab_batch], 3)
    index_n_batch = tf.concat([index_l_batch, index_ab_batch], 3)
    index_n_batch = tf.reshape(index_n_batch, [BATCH_SIZE, 1, -1])

    # out_batch [BATCH_SIZE, 1, 21]
    out_batch = model.built_network(train_l_batch)
    print(out_batch)
    sess = tf.Session()
    global_step = tf.train.get_or_create_global_step(sess.graph)
    train_loss = model.whole_loss(out_batch, index_n_batch)
    train_op = model.training(train_loss, global_step)

    # float64 views used only for the preview images.
    out_lab_batch = tf.cast(tf.reshape(out_batch, [BATCH_SIZE, 1, 7, 3]), tf.float64)
    index_n_batch = tf.cast(tf.reshape(index_n_batch, [BATCH_SIZE, 1, 7, 3]), tf.float64)
    train_n_batch = tf.cast(train_n_batch, tf.float64)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs_dir, sess.graph)
    saver = tf.train.Saver(max_to_keep=20)
    checkpoint_path = os.path.join(logs_dir, "model.ckpt")
    # Integer interval: the old float `MAX_STEP / 20` missed most checkpoints
    # whenever MAX_STEP was not a multiple of 20.
    checkpoint_every = max(MAX_STEP // 20, 1)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    def denormalize(lab):
        """Undo the [0, 1] normalisation in place along the last axis; return lab."""
        lab[..., 0:1] = lab[..., 0:1] * 100
        lab[..., 1:] = lab[..., 1:] * 255 - 128
        return lab

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break
            log_now = step % 100 == 0
            # Fetch the summary in the same run as the training step so it
            # describes the same batch as the printed loss (a separate run
            # would dequeue, and thereby skip, another training batch).
            if log_now:
                _, tra_loss, merged = sess.run([train_op, train_loss, summary_op])
            else:
                _, tra_loss = sess.run([train_op, train_loss])
            if isnan(tra_loss):
                print('Loss is NaN.')
                saver.save(sess, checkpoint_path, global_step=step)
                sys.exit(-1)
            if log_now:
                # Record the loss trend regularly.
                train_writer.add_summary(merged, step)
                print("Step: %d, loss: %g" % (step, tra_loss))
            if step % checkpoint_every == 0 or step == MAX_STEP - 1:
                # Save about 20 checkpoints over the whole run.
                saver.save(sess, checkpoint_path, global_step=step)
            if log_now:
                train_lab, index_lab, out_lab = sess.run(
                    [train_n_batch, index_n_batch, out_lab_batch])
                train_lab = denormalize(train_lab[0])
                index_lab = denormalize(index_lab[0])
                out_lab = denormalize(out_lab[0])
                print(out_lab)
                previews = (color.lab2rgb(train_lab),
                            color.lab2rgb(index_lab),
                            color.lab2rgb(out_lab))
                # A fresh figure per preview; reusing the implicit figure
                # stacked images on top of each other and leaked memory.
                fig = plt.figure()
                for pos, rgb in enumerate(previews, start=1):
                    plt.subplot(1, 3, pos)
                    plt.imshow(rgb)
                fig.savefig(result_dir + str(step) + "_image.png")
                plt.close(fig)
    except tf.errors.OutOfRangeError:
        print("Done.")
    finally:
        coord.request_stop()
        # Wait for the queue-runner threads to finish.
        coord.join(threads=threads)
        sess.close()
def run_training():
    """Train the global theme-recolouring network.

    Builds the input pipeline from ``input_data.get_themeInput_list`` /
    ``get_themeObj_batch``, converts the RGB batches to Lab and normalises
    them (L / 100, (ab + 128) / 255). It then trains
    ``model.new_built_network`` on the image ab channels plus the theme ab
    channels concatenated with the theme mask.
    - Every 100 steps it writes a summary and prints the losses, RMSE and PSNR.
    - About 20 checkpoints are saved over ``MAX_STEP`` steps.
    - Every 2000 steps it saves and shows an image grid and an ab scatter plot
      under ``result_dir``.
    - If the loss becomes NaN, a checkpoint is saved and the process exits
      with status -1.

    NOTE(review): this module defines ``run_training`` more than once; only
    the last definition is bound at import time.
    """
    import sys

    train_dir = "G:\\Database\\ColoredData\\new_colorimage1"
    theme_index_dir = "G:\\Database\\ColoredData\\ColorMap5_image4"
    image_index_dir = "G:\\Database\\ColoredData\\new_colorimage4"
    theme_dir = "G:\\Database\\ColoredData\\colorImages4_5theme"
    logs_dir = "F:\\Project_Yang\\Code\\mainProject\\logs\\log_global\\image loss 0.2"
    result_dir = "results/global/"

    # Build the input pipeline.
    image_list = input_data.get_themeInput_list(
        train_dir, theme_dir, theme_index_dir, image_index_dir)
    train_batch, theme_batch, theme_index_batch, theme_mask_batch, image_index_batch = \
        input_data.get_themeObj_batch(image_list, BATCH_SIZE, CAPACITY)

    def to_lab32(rgb):
        """RGB -> Lab: rgb_to_lab is fed float64, the result is used as float32."""
        return tf.cast(input_data.rgb_to_lab(tf.cast(rgb, tf.float64)), tf.float32)

    train_lab_batch = to_lab32(train_batch)
    theme_lab_batch = to_lab32(theme_batch)
    index_lab_batch = to_lab32(image_index_batch)
    themeIndex_lab_batch = to_lab32(theme_index_batch)

    # Do any + - * / on raw Lab values before this point, then normalise:
    # L / 100 and (ab + 128) / 255.
    image_l_batch = tf.reshape(train_lab_batch[:, :, :, 0] / 100,
                               [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1])
    image_ab_batch = (train_lab_batch[:, :, :, 1:] + 128) / 255
    theme_ab_batch = (theme_lab_batch[:, :, :, 1:] + 128) / 255
    index_ab_batch = (index_lab_batch[:, :, :, 1:] + 128) / 255
    themeIndex_ab_batch = (themeIndex_lab_batch[:, :, :, 1:] + 128) / 255

    # Network input: image ab + (theme ab concatenated with the theme mask).
    theme_input = tf.concat([theme_ab_batch, theme_mask_batch], 3)
    out_ab_batch = model.new_built_network(image_ab_batch, theme_input)

    # float64 copies used only for the preview images.
    image_l_batch = tf.cast(image_l_batch, tf.float64)
    theme_lab_batch = tf.cast(theme_lab_batch, tf.float64)

    sess = tf.Session()
    global_step = tf.train.get_or_create_global_step(sess.graph)
    train_loss, index_loss, color_loss, image_loss = model.whole_loss(
        out_ab_batch, index_ab_batch, themeIndex_ab_batch, image_ab_batch)
    train_rmse, train_psnr = model.get_PSNR(out_ab_batch, index_ab_batch)
    train_op = model.training(train_loss, global_step)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs_dir, sess.graph)
    saver = tf.train.Saver(max_to_keep=20)
    checkpoint_path = os.path.join(logs_dir, "model.ckpt")
    # Integer interval: the old float `MAX_STEP / 20` missed most checkpoints
    # whenever MAX_STEP was not a multiple of 20.
    checkpoint_every = max(MAX_STEP // 20, 1)
    # savefig below fails if the directory does not exist yet.
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    def lab_to_rgb(l, ab_part):
        """Join denormalised L and ab along the channel axis and convert to RGB."""
        return color.lab2rgb(np.concatenate([l, ab_part], 2))

    step_fetches = [train_op, train_loss, index_loss, color_loss, image_loss,
                    train_rmse, train_psnr]

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break
            log_now = step % 100 == 0
            # A single run per step: separate runs for RMSE/PSNR or summaries
            # would each dequeue a different batch than the one trained on.
            fetches = (step_fetches + [summary_op]) if log_now else step_fetches
            results = sess.run(fetches)
            _, tra_loss, ind_loss, col_loss, img_loss, tra_rmse, tra_psnr = results[:7]
            if isnan(tra_loss):
                print('Loss is NaN.')
                saver.save(sess, checkpoint_path, global_step=step)
                sys.exit(-1)
            if log_now:
                # Record the loss trend regularly.
                train_writer.add_summary(results[7], step)
                print(
                    "Step: %d, loss: %g, index_loss: %g, color_loss: %g, image_loss: %g, rmse: %g, psnr: %g"
                    % (step, tra_loss, ind_loss, col_loss, img_loss, tra_rmse, tra_psnr))
            if step % checkpoint_every == 0 or step == MAX_STEP - 1:
                # Save about 20 checkpoints over the whole run.
                saver.save(sess, checkpoint_path, global_step=step)
            if step % 2000 == 0:
                l, ab, ab_index, ab_out, theme_lab, colored = sess.run([
                    image_l_batch, image_ab_batch, index_ab_batch, out_ab_batch,
                    theme_lab_batch, themeIndex_ab_batch
                ])
                # Undo the normalisation for display.
                l = l[0] * 100
                ab = ab[0] * 255 - 128
                ab_index = ab_index[0] * 255 - 128
                ab_out = ab_out[0] * 255 - 128
                colored = colored[0] * 255 - 128
                rows = (
                    (ab, lab_to_rgb(l, ab)),
                    (ab_out, lab_to_rgb(l, ab_out)),
                    (ab_index, lab_to_rgb(l, ab_index)),
                    (colored, lab_to_rgb(l, colored)),
                )
                theme = color.lab2rgb(theme_lab[0])

                # 5x4 grid: one row per variant (L, a, b, RGB), then the theme.
                # A fresh figure each time and close it afterwards; reusing the
                # current figure drew over the previous scatter plot.
                fig = plt.figure()
                for row, (ab_part, rgb) in enumerate(rows):
                    first = 4 * row + 1
                    plt.subplot(5, 4, first)
                    plt.imshow(l[:, :, 0], 'gray')
                    plt.subplot(5, 4, first + 1)
                    plt.imshow(ab_part[:, :, 0], 'gray')
                    plt.subplot(5, 4, first + 2)
                    plt.imshow(ab_part[:, :, 1], 'gray')
                    plt.subplot(5, 4, first + 3)
                    plt.imshow(rgb)
                plt.subplot(5, 4, 17)
                plt.imshow(theme)
                fig.savefig(result_dir + str(step) + "_image.png")
                plt.show()
                plt.close(fig)

                # ab scatter plots, one per variant plus an overlay of all four.
                fig = plt.figure(figsize=(8, 8))
                panels = ((231, ab, 'input images'),
                          (232, ab_out, 'output images'),
                          (233, ab_index, 'index images'),
                          (234, colored, 'colored images'))
                for pos, ab_part, title in panels:
                    axes = plt.subplot(pos)
                    axes.scatter(ab_part[:, :, 0], ab_part[:, :, 1],
                                 alpha=0.5, edgecolor='white', s=8)
                    plt.xlabel('a')
                    plt.ylabel('b')
                    plt.title(title)
                axes5 = plt.subplot(235)
                part1 = axes5.scatter(ab[:, :, 0], ab[:, :, 1], alpha=0.5,
                                      edgecolor='white', label='image_in', s=8)
                part2 = axes5.scatter(ab_index[:, :, 0], ab_index[:, :, 1], alpha=0.5,
                                      edgecolor='white', label='image_index', c='g', s=8)
                part3 = axes5.scatter(colored[:, :, 0], colored[:, :, 1], alpha=0.5,
                                      edgecolor='white', label='image_colored', c='y', s=8)
                part4 = axes5.scatter(ab_out[:, :, 0], ab_out[:, :, 1], alpha=0.5,
                                      edgecolor='white', label='image_out', c='r', s=8)
                plt.xlabel('a')
                plt.ylabel('b')
                # The legend belongs on the overlay axes (previously axes4).
                axes5.legend((part1, part2, part3, part4),
                             ('input', 'index', 'colored', 'output'))
                fig.savefig(result_dir + str(step) + "_scatter.png")
                plt.show()
                plt.close(fig)
    except tf.errors.OutOfRangeError:
        print("Done.")
    finally:
        coord.request_stop()
        # Wait for the queue-runner threads to finish.
        coord.join(threads=threads)
        sess.close()
def run_training():
    """Train the local (sparse-hint) recolouring network.

    Builds the input pipeline from ``input_data.get_local_list`` /
    ``get_local_batch``, converts the RGB batches to Lab and normalises them
    (L / 100, (ab + 128) / 255). It then trains ``model.built_network`` on
    the image ab channels plus the sparse ab hints concatenated with the
    first mask channel.
    - Every 100 steps it writes a summary and prints loss, RMSE and PSNR.
    - About 20 checkpoints are saved over ``MAX_STEP`` steps.
    - Every 1000 steps it saves and shows an image grid and an ab scatter plot
      under ``result_dir``.
    - If the loss becomes NaN, a checkpoint is saved and the process exits
      with status -1.

    NOTE(review): this module defines ``run_training`` more than once; this
    is the last definition in the visible source, so it is presumably the one
    bound at import time — confirm.
    """
    import sys

    train_dir = "F:\\Database\\ColoredData\\new_colorimage1"
    index_dir = "F:\\Database\\ColoredData\\new_colorimage4"
    mask_dir = "F:\\Database\\ColoredData\\newSparseMask"
    sparse_dir = "F:\\Database\\ColoredData\\newSparse"
    logs_dir = "E:\\Project_Yang\\Code\\logs\\local\\local3"
    result_dir = "results/local/local33/"

    # Build the input pipeline.
    image_list = input_data.get_local_list(train_dir, sparse_dir, mask_dir, index_dir)
    train_rgb_batch, sparse_rgb_batch, mask_2channels_batch, index_rgb_batch = \
        input_data.get_local_batch(image_list, BATCH_SIZE, CAPACITY)
    train_lab_batch = tf.cast(input_data.rgb_to_lab(train_rgb_batch), tf.float32)
    sparse_lab_batch = tf.cast(input_data.rgb_to_lab(sparse_rgb_batch), tf.float32)
    index_lab_batch = tf.cast(input_data.rgb_to_lab(index_rgb_batch), tf.float32)
    mask_2channels_batch = tf.cast(mask_2channels_batch, tf.float32)

    # Do any + - * / on raw Lab values before this point, then normalise:
    # L / 100 and (ab + 128) / 255.
    train_l_batch = train_lab_batch[:, :, :, 0:1] / 100
    train_ab_batch = (train_lab_batch[:, :, :, 1:] + 128) / 255
    sparse_ab_batch = (sparse_lab_batch[:, :, :, 1:] + 128) / 255
    index_ab_batch = (index_lab_batch[:, :, :, 1:] + 128) / 255

    # Sparse hint input: sparse ab channels + the first mask channel.
    sparse_input = tf.concat(
        [sparse_ab_batch, mask_2channels_batch[:, :, :, 0:1]], 3)
    out_ab_batch = model.built_network(train_ab_batch, sparse_input)

    sess = tf.Session()
    global_step = tf.train.get_or_create_global_step(sess.graph)
    train_loss, _separate_loss = model.whole_loss(
        out_ab_batch, index_ab_batch, train_ab_batch, mask_2channels_batch)
    train_rmse, train_psnr = model.get_PSNR(out_ab_batch, index_ab_batch)
    train_op = model.training(train_loss, global_step)
    # float64 copy of L used only for the preview images.
    train_l_batch = tf.cast(train_l_batch, tf.float64)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs_dir, sess.graph)
    saver = tf.train.Saver(max_to_keep=20)
    checkpoint_path = os.path.join(logs_dir, "model.ckpt")
    # Integer interval: the old float `MAX_STEP / 20` missed most checkpoints
    # whenever MAX_STEP was not a multiple of 20.
    checkpoint_every = max(MAX_STEP // 20, 1)
    # Must exist before the first savefig at step 0. Previously it was created
    # only after the first save, so a fresh run crashed.
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    def lab_to_rgb(l, ab_part):
        """Join denormalised L and ab along the channel axis and convert to RGB."""
        return color.lab2rgb(np.concatenate([l, ab_part], 2))

    step_fetches = [train_op, train_loss, train_rmse, train_psnr]

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break
            log_now = step % 100 == 0
            # A single run per step: separate runs for RMSE/PSNR or summaries
            # would each dequeue a different batch than the one trained on.
            fetches = (step_fetches + [summary_op]) if log_now else step_fetches
            results = sess.run(fetches)
            _, tra_loss, tra_rmse, tra_psnr = results[:4]
            if isnan(tra_loss):
                print('Loss is NaN.')
                saver.save(sess, checkpoint_path, global_step=step)
                sys.exit(-1)
            if log_now:
                # Record the loss trend regularly.
                train_writer.add_summary(results[4], step)
                print("Step: %d, loss: %g RMSE: %g, PSNR: %g"
                      % (step, tra_loss, tra_rmse, tra_psnr))
            if step % checkpoint_every == 0 or step == MAX_STEP - 1:
                # Save about 20 checkpoints over the whole run.
                saver.save(sess, checkpoint_path, global_step=step)
            if step % 1000 == 0:
                l, ab, ab_index, ab_out = sess.run([
                    train_l_batch, train_ab_batch, index_ab_batch, out_ab_batch
                ])
                # Undo the normalisation for display.
                l = l[0] * 100
                ab = ab[0] * 255 - 128
                ab_out = ab_out[0] * 255 - 128
                ab_index = ab_index[0] * 255 - 128
                rows = (
                    (ab, lab_to_rgb(l, ab)),
                    (ab_out, lab_to_rgb(l, ab_out)),
                    (ab_index, lab_to_rgb(l, ab_index)),
                )

                # 3x4 grid: one row per variant (L, a, b, RGB). Use a fresh
                # figure and close it afterwards; reusing the current figure
                # drew over the previous scatter plot.
                fig = plt.figure()
                for row, (ab_part, rgb) in enumerate(rows):
                    first = 4 * row + 1
                    plt.subplot(3, 4, first)
                    plt.imshow(l[:, :, 0], 'gray')
                    plt.subplot(3, 4, first + 1)
                    plt.imshow(ab_part[:, :, 0], 'gray')
                    plt.subplot(3, 4, first + 2)
                    plt.imshow(ab_part[:, :, 1], 'gray')
                    plt.subplot(3, 4, first + 3)
                    plt.imshow(rgb)
                fig.savefig(result_dir + str(step) + "_image.png")
                plt.show()
                plt.close(fig)

                # ab scatter plots, one per variant plus an overlay of all three.
                fig = plt.figure(figsize=(8, 8))
                panels = ((221, ab, 'input images'),
                          (222, ab_out, 'output images'),
                          (223, ab_index, 'index images'))
                for pos, ab_part, title in panels:
                    axes = plt.subplot(pos)
                    axes.scatter(ab_part[:, :, 0], ab_part[:, :, 1],
                                 alpha=0.5, edgecolor='white', s=8)
                    plt.xlabel('a')
                    plt.ylabel('b')
                    plt.title(title)
                axes4 = plt.subplot(224)
                part1 = axes4.scatter(ab[:, :, 0], ab[:, :, 1], alpha=0.5,
                                      edgecolor='white', label='image_in', s=8)
                part2 = axes4.scatter(ab_index[:, :, 0], ab_index[:, :, 1], alpha=0.5,
                                      edgecolor='white', label='image_index', c='g', s=8)
                part3 = axes4.scatter(ab_out[:, :, 0], ab_out[:, :, 1], alpha=0.5,
                                      edgecolor='white', label='image_out', c='r', s=8)
                plt.xlabel('a')
                plt.ylabel('b')
                axes4.legend((part1, part2, part3), ('input', 'index', 'output'))
                fig.savefig(result_dir + str(step) + "_scatter.png")
                plt.show()
                plt.close(fig)
    except tf.errors.OutOfRangeError:
        print("Done.")
    finally:
        coord.request_stop()
        # Wait for the queue-runner threads to finish.
        coord.join(threads=threads)
        sess.close()